diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 16 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 16 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 23 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 |
4 files changed, 34 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index c8c9c0bd4af8..e6fce56d4140 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op, offsetsCopy[dimension] = offset; // Create the part as it it were a single tile. - SmallVector<Operation *> tiled = + FailureOr<TilingResult> tilingResult = op.getTiledImplementation(b, offsetsCopy, sizesCopy); - assert(tiled.size() == 1 && "expected a single result from tiling"); - auto part = cast<TilingInterface>(tiled.front()); // Insert the results back and populate the `results` list. - for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) { + for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { SmallVector<OpFoldResult> resultOffsets, resultSizes; - if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, + if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, resultOffsets, resultSizes))) return nullptr; SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), b.getIndexAttr(1)); Value inserted = b.create<tensor::InsertSliceOp>( - loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, + loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); } - - return part; + // TODO: this part can be generalized maybe to not expect a single op. + assert(tilingResult->tiledOps.size() == 1 && + "expected split part to return a single tiled operation"); + return cast<TilingInterface>(tilingResult->tiledOps[0]); } std::pair<TilingInterface, TilingInterface> diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 62eef97a1744..1e404cabbb51 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl( } // 4. Tile the cloned op and delete the clone. - SmallVector<Operation *> tiledOps = + FailureOr<TilingResult> tilingResult = cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets, tiledSizes); b.eraseOp(clonedOp); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); } // 5. Parallel insert back into the result tensor. @@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( // 5. Tile the cloned op and delete the clone. if (tileSizes.empty()) { - SmallVector<Operation *> tiledOps = + FailureOr<TilingResult> tilingResult = cast<TilingInterface>(clonedOp).getTiledImplementation( b, tiledOffsets, tiledSizes); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); - tilingResults = tiledOp->getResults(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; } else { LinalgTilingOptions options; FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index cfc27ca44e42..676d6330cde3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -111,7 +111,7 @@ struct LinalgOpTilingInterface } // Instantiate the tiled implementation of the operation. - SmallVector<Operation *> + FailureOr<TilingResult> getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { @@ -129,7 +129,7 @@ struct LinalgOpTilingInterface Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast<LinalgOp>(tiledOp), offsets); - return {tiledOp}; + return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())}; } // Return the details of the output tile generated by the tiled @@ -160,10 +160,10 @@ struct LinalgOpTilingInterface return success(); } - FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes) const { + FailureOr<TilingResult> + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { auto linalgOp = cast<LinalgOp>(op); // Check that the indexing map used for the output is a projected @@ -197,12 +197,15 @@ struct LinalgOpTilingInterface iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; } - SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation( - b, iterationTileOffsets, iterationTileSizes); - if (tiledOp.size() != 1) + FailureOr<TilingResult> tilingResult = + tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, + iterationTileSizes); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); - return tiledOp[0]->getResult(resultNumber); + return TilingResult{ + tilingResult->tiledOps, + SmallVector<Value>{tilingResult->tiledValues[resultNumber]}}; } LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 17c46182eb5d..e001f59b21e9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( return failure(); } - Operation *tiledPadOp = + FailureOr<TilingResult> tilingResult = tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), zeroSliceGuard); + if (failed(tilingResult)) + return failure(); // All shapes are static and the data source is actually used. Rewrite into // pad(extract_slice(x)). - rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); + rewriter.replaceOp(sliceOp, tilingResult->tiledValues); return success(); } |