summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-01-12 22:38:57 -0500
committerNicolas Vasilache <ntv@google.com>2020-01-13 10:56:07 -0500
commite653d306ce90e5612796d8adce9eb34b1c10e85a (patch)
treeee83cef73ea9e8c6206ee37f246ffdddfd2d1a8c /mlir
parentb4a99a061f517e60985667e39519f60186cbb469 (diff)
downloadllvm-e653d306ce90e5612796d8adce9eb34b1c10e85a.tar.gz
[mlir][Linalg] Update ReshapeOp::build to be more idiomatic
Summary: This diff makes it easier to create a `linalg.reshape` op and adds an EDSC builder api test to exercise the new builders. Reviewers: ftynse, jpienaar Subscribers: mehdi_amini, rriddle, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72580
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h2
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td14
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp48
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp27
4 files changed, 82 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
index 42b286d504f6..7777f5ceebf2 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -17,7 +17,7 @@ namespace edsc {
namespace intrinsics {
using linalg_fill = OperationBuilder<linalg::FillOp>;
-using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
+using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
using linalg_yield = OperationBuilder<linalg::YieldOp>;
} // namespace intrinsics
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 8e1f40f3620b..db15f516adc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -100,9 +100,17 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
```
}];
- let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value view, "
- "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
+ let builders = [
+ // Builder for a contracting reshape whose result type is computed from
+ // `view` and `reassociation`.
+ OpBuilder<"Builder *b, OperationState &result, Value view, "
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Builder for a reshape whose result type is passed explicitly. This may be
+ // either a contracting or expanding reshape.
+ OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">];
let extraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cf27a817edb1..e244542f9b49 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -465,14 +465,52 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
[](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
}
-void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
- Value view, ArrayAttr reassociation,
- ArrayRef<NamedAttribute> attrs) {
- auto maps = getAffineMaps(reassociation);
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+ unsigned pos = 0;
+ for (auto exprs : exprArrays) {
+ for (auto expr : exprs) {
+ expr.walk([&pos](AffineExpr e) {
+ if (auto d = e.dyn_cast<AffineExprTy>())
+ pos = std::max(pos, d.getPosition());
+ });
+ }
+ }
+ return pos;
+}
+
+static SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+ unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+ unsigned maxSym = getMaxPosOfType<AffineSymbolExpr>(reassociation);
+ assert(maxSym == 0 && "Expected symbol-less expressions");
+ SmallVector<AffineMap, 4> maps;
+ maps.reserve(reassociation.size());
+ for (auto exprs : reassociation)
+ maps.push_back(AffineMap::get(maxDim + 1, 0, exprs));
+ return maps;
+}
+
+void mlir::linalg::ReshapeOp::build(
+ Builder *b, OperationState &result, Value view,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
auto memRefType = view.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(memRefType, maps);
build(b, result, resultType, view, attrs);
- result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+ result.addAttribute(ReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::ReshapeOp::build(
+ Builder *b, OperationState &result, Type resultType, Value view,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
+ build(b, result, resultType, view, attrs);
+ result.addAttribute(ReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
}
static void print(OpAsmPrinter &p, ReshapeOp op) {
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index fcd5e37e4ef6..7ddfe50130c8 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Builders.h"
@@ -962,6 +963,32 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
f.erase();
}
+// clang-format off
+// CHECK-LABEL: func @linalg_metadata_ops
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32>
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32>
+// clang-format on
+TEST_FUNC(linalg_metadata_ops) {
+ using namespace edsc;
+ using namespace edsc::intrinsics;
+
+ auto f32Type = FloatType::getF32(&globalContext());
+ auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
+ auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
+
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ AffineExpr i, j, k;
+ bindDims(&globalContext(), i, j, k);
+ ValueHandle v(f.getArgument(0));
+ auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+ linalg_reshape(memrefType, reshaped,
+ ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+
+ f.print(llvm::outs());
+ f.erase();
+}
+
int main() {
RUN_TESTS();
return 0;