diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-01-12 22:38:57 -0500 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-01-13 10:56:07 -0500 |
commit | e653d306ce90e5612796d8adce9eb34b1c10e85a (patch) | |
tree | ee83cef73ea9e8c6206ee37f246ffdddfd2d1a8c /mlir | |
parent | b4a99a061f517e60985667e39519f60186cbb469 (diff) | |
download | llvm-e653d306ce90e5612796d8adce9eb34b1c10e85a.tar.gz |
[mlir][Linalg] Update ReshapeOp::build to be more idiomatic
Summary:
This diff makes it easier to create a `linalg.reshape` op
and adds an EDSC builder api test to exercise the new builders.
Reviewers: ftynse, jpienaar
Subscribers: mehdi_amini, rriddle, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72580
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h | 2 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 14 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 48 | ||||
-rw-r--r-- | mlir/test/EDSC/builder-api-test.cpp | 27 |
4 files changed, 82 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h index 42b286d504f6..7777f5ceebf2 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -17,7 +17,7 @@ namespace edsc { namespace intrinsics { using linalg_fill = OperationBuilder<linalg::FillOp>; -using linalg_reshape = OperationBuilder<linalg::ReshapeOp>; +using linalg_reshape = ValueBuilder<linalg::ReshapeOp>; using linalg_yield = OperationBuilder<linalg::YieldOp>; } // namespace intrinsics diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 8e1f40f3620b..db15f516adc8 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -100,9 +100,17 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>, ``` }]; - let builders = [OpBuilder< - "Builder *b, OperationState &result, Value view, " - "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">]; + let builders = [ + // Builder for a contracting reshape whose result type is computed from + // `view` and `reassociation`. + OpBuilder<"Builder *b, OperationState &result, Value view, " + "ArrayRef<ArrayRef<AffineExpr>> reassociation, " + "ArrayRef<NamedAttribute> attrs = {}">, + // Builder for a reshape whose result type is passed explicitly. This may be + // either a contracting or expanding reshape. + OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view," + "ArrayRef<ArrayRef<AffineExpr>> reassociation, " + "ArrayRef<NamedAttribute> attrs = {}">]; let extraClassDeclaration = [{ static StringRef getReassociationAttrName() { return "reassociation"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cf27a817edb1..e244542f9b49 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -465,14 +465,52 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) { [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs); } -void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result, - Value view, ArrayAttr reassociation, - ArrayRef<NamedAttribute> attrs) { - auto maps = getAffineMaps(reassociation); +template <typename AffineExprTy> +unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) { + unsigned pos = 0; + for (auto exprs : exprArrays) { + for (auto expr : exprs) { + expr.walk([&pos](AffineExpr e) { + if (auto d = e.dyn_cast<AffineExprTy>()) + pos = std::max(pos, d.getPosition()); + }); + } + } + return pos; +} + +static SmallVector<AffineMap, 4> +getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) { + unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); + unsigned maxSym = getMaxPosOfType<AffineSymbolExpr>(reassociation); + assert(maxSym == 0 && "Expected symbol-less expressions"); + SmallVector<AffineMap, 4> maps; + maps.reserve(reassociation.size()); + for (auto exprs : reassociation) + maps.push_back(AffineMap::get(maxDim + 1, 0, exprs)); + return maps; +} + +void mlir::linalg::ReshapeOp::build( + Builder *b, OperationState &result, Value view, + ArrayRef<ArrayRef<AffineExpr>> reassociation, + ArrayRef<NamedAttribute> attrs) { + auto maps = getSymbolLessAffineMaps(reassociation); auto memRefType = view.getType().cast<MemRefType>(); auto resultType = computeReshapeCollapsedType(memRefType, maps); build(b, result, resultType, view, attrs); - result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation); + result.addAttribute(ReshapeOp::getReassociationAttrName(), + b->getAffineMapArrayAttr(maps)); +} + +void mlir::linalg::ReshapeOp::build( + Builder *b, OperationState &result, Type resultType, Value view, + ArrayRef<ArrayRef<AffineExpr>> reassociation, + ArrayRef<NamedAttribute> attrs) { + auto maps = getSymbolLessAffineMaps(reassociation); + build(b, result, resultType, view, attrs); + result.addAttribute(ReshapeOp::getReassociationAttrName(), + b->getAffineMapArrayAttr(maps)); } static void print(OpAsmPrinter &p, ReshapeOp op) { diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index fcd5e37e4ef6..7ddfe50130c8 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" @@ -962,6 +963,32 @@ TEST_FUNC(linalg_dilated_conv_nhwc) { f.erase(); } +// clang-format off +// CHECK-LABEL: func @linalg_metadata_ops +// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32> +// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32> +// clang-format on +TEST_FUNC(linalg_metadata_ops) { + using namespace edsc; + using namespace edsc::intrinsics; + + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0); + auto f = makeFunction("linalg_metadata_ops", {}, {memrefType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + AffineExpr i, j, k; + bindDims(&globalContext(), i, j, k); + ValueHandle v(f.getArgument(0)); + auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k}); + linalg_reshape(memrefType, reshaped, + ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k}); + + f.print(llvm::outs()); + f.erase(); +} + int main() { RUN_TESTS(); return 0; |