summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Split.cpp16
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp16
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp23
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp6
4 files changed, 34 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index c8c9c0bd4af8..e6fce56d4140 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
offsetsCopy[dimension] = offset;
// Create the part as it it were a single tile.
- SmallVector<Operation *> tiled =
+ FailureOr<TilingResult> tilingResult =
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
- assert(tiled.size() == 1 && "expected a single result from tiling");
- auto part = cast<TilingInterface>(tiled.front());
// Insert the results back and populate the `results` list.
- for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
+ for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
- if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
+ if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
resultOffsets, resultSizes)))
return nullptr;
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
b.getIndexAttr(1));
Value inserted = b.create<tensor::InsertSliceOp>(
- loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
+ loc, result, resultOperands[index], resultOffsets, resultSizes,
resultStrides);
results.push_back(inserted);
}
-
- return part;
+ // TODO: this part can be generalized maybe to not expect a single op.
+ assert(tilingResult->tiledOps.size() == 1 &&
+ "expected split part to return a single tiled operation");
+ return cast<TilingInterface>(tilingResult->tiledOps[0]);
}
std::pair<TilingInterface, TilingInterface>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 62eef97a1744..1e404cabbb51 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
}
// 4. Tile the cloned op and delete the clone.
- SmallVector<Operation *> tiledOps =
+ FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
tiledSizes);
b.eraseOp(clonedOp);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
+ assert(tilingResult->tiledOps.size() == 1 &&
+ "expected a single produced tiled op");
+ tiledOp = tilingResult->tiledOps.front();
}
// 5. Parallel insert back into the result tensor.
@@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 5. Tile the cloned op and delete the clone.
if (tileSizes.empty()) {
- SmallVector<Operation *> tiledOps =
+ FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(
b, tiledOffsets, tiledSizes);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
- tilingResults = tiledOp->getResults();
+ assert(tilingResult->tiledOps.size() == 1 &&
+ "expected a single produced tiled op");
+ tiledOp = tilingResult->tiledOps.front();
+ tilingResults = tilingResult->tiledValues;
} else {
LinalgTilingOptions options;
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index cfc27ca44e42..676d6330cde3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -111,7 +111,7 @@ struct LinalgOpTilingInterface
}
// Instantiate the tiled implementation of the operation.
- SmallVector<Operation *>
+ FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
@@ -129,7 +129,7 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return {tiledOp};
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
// Return the details of the output tile generated by the tiled
@@ -160,10 +160,10 @@ struct LinalgOpTilingInterface
return success();
}
- FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
- unsigned resultNumber,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult>
+ generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
@@ -197,12 +197,15 @@ struct LinalgOpTilingInterface
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}
- SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
- b, iterationTileOffsets, iterationTileSizes);
- if (tiledOp.size() != 1)
+ FailureOr<TilingResult> tilingResult =
+ tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
+ iterationTileSizes);
+ if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
- return tiledOp[0]->getResult(resultNumber);
+ return TilingResult{
+ tilingResult->tiledOps,
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
}
LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 17c46182eb5d..e001f59b21e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return failure();
}
- Operation *tiledPadOp =
+ FailureOr<TilingResult> tilingResult =
tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), zeroSliceGuard);
+ if (failed(tilingResult))
+ return failure();
// All shapes are static and the data source is actually used. Rewrite into
// pad(extract_slice(x)).
- rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
+ rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
return success();
}