diff options
author | Mahesh Ravishankar <ravishankarm@google.com> | 2023-03-01 16:33:14 -0800 |
---|---|---|
committer | Mahesh Ravishankar <ravishankarm@google.com> | 2023-03-16 14:29:03 +0000 |
commit | 809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365 (patch) | |
tree | 60ea5dbd90671206b2bd5b5d20fc6d0ea89db85c /mlir/lib | |
parent | a586c551000bcd874852ea1265f6dac4b3d894b3 (diff) | |
download | llvm-809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365.tar.gz |
[mlir][TilingInterface] Modify `TilingInterface` methods to better return the state of the transformed IR.
Currently the `getTiledImplementation` and `generateResultTileValue`
return just `SmallVector<Operation *>` and `FailureOr<Value>`.
- For `getTiledImplementation` returning empty implies tiling wasnt
done. There is also an implicit assumption that the tiled operation
results correspond to the tiled values of the result of the original
operation. This cannot handle cases where the tiled implementation
might use multiple operations to compute the tiled value for the
results of the untiled operation. Sometimes, the tiled operation
might not directly give the tiled values, and might require casts,
etc to get a replacement.
- For `generateResultTileValue`, it is assumed that the op defining
the returned `Value` is the operation that represents the tiled
computation. Again presence of casts, etc violate this.
Instead make these methods return
```
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
};
```
The `tiledOps` represent the operations generated that are relevant
for subsequent transformations. The `tiledValues` represent the tiled
values for the results of the original operation. This better
transmits the state of the transformed IR.
As a consequence the following methods also return `FailureOr<TilingResult>`
- `tensor::replaceExtractSliceWithTiledProducer`
- `tensor::bubbleUpPadSlice`
Differential Revision: https://reviews.llvm.org/D145133
Diffstat (limited to 'mlir/lib')
8 files changed, 137 insertions, 115 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 4aae6458ff12..4503d451a405 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -431,16 +431,15 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder, /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, - Diagnostic &diag, - Operation *producerOp, - Operation *containingOp) { +static SmallVector<Operation *> +tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, + Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast<TilingInterface>(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the producer slices accessed within the containing operation. @@ -455,7 +454,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, if (it == tileableProducer->getUsers().end()) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; - return nullptr; + return {}; } auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it); @@ -468,27 +467,29 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, sliceOpToTile.getSource().cast<OpResult>().getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); - FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue( - rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), - sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + FailureOr<TilingResult> tileAndFuseResult = + tileableProducer.generateResultTileValue(rewriter, resultNumber, + sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; + } + for (auto tiledOp : tileAndFuseResult->tiledOps) { + LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast<RankedTensorType>() .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); - return fusedOp; + return tileAndFuseResult->tiledOps; } /// First, find the first "scf::ForallOp" user of `producerOp` and ensure @@ -497,7 +498,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, /// right before its "extract" use. The tiled op is fused under the /// `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( +static SmallVector<Operation *> +tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); @@ -506,7 +508,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the first use by a "scf::ForallOp" user. @@ -520,7 +522,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (!forallOp || forallOp != containingOp) { diag.attachNote(tileableProducer->getLoc()) << "could not find a use by the containing op: " << *tileableProducer; - return nullptr; + return {}; } // Search the producer slices accessed within the containing @@ -542,7 +544,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (itBBArgUsers == bbArg.getUsers().end()) { diag.attachNote(containingOp->getLoc()) << "could not find fusion opportunity for bbArg: " << bbArg; - return nullptr; + return {}; } auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers); @@ -562,7 +564,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( destinationTensors))) { diag.attachNote(tileableProducer->getLoc()) << "failed to get destination tensors for: " << *tileableProducer; - return nullptr; + return {}; } IRMapping bvm; @@ -573,21 +575,19 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); // Tile the producer. - FailureOr<Value> tiledProducer = + FailureOr<TilingResult> tileAndFuseResult = tileableProducerClone.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast<RankedTensorType>() @@ -601,7 +601,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( destinationTensors.front()); }); - return fusedOp; + return tileAndFuseResult->tiledOps; } static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, @@ -714,21 +714,21 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, // cases, we can tile/clone once and reuse the value for each use. // Futhermore, producers should then be traversed according to a // topological sorting. - Operation *tiled = + SmallVector<Operation *> tiledOps = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); - if (tiled) { + if (!tiledOps.empty()) { LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); - fusedOps.push_back(tiled); + fusedOps.append(tiledOps); continue; } - Operation *tiledContainingOpOperand = + SmallVector<Operation *> tiledContainingOpOperand = tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); - if (tiledContainingOpOperand) { + if (!tiledContainingOpOperand.empty()) { LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" << *containingOp); - fusedOps.push_back(tiledContainingOpOperand); + fusedOps.append(tiledContainingOpOperand); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index c8c9c0bd4af8..e6fce56d4140 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op, offsetsCopy[dimension] = offset; // Create the part as it it were a single tile. - SmallVector<Operation *> tiled = + FailureOr<TilingResult> tilingResult = op.getTiledImplementation(b, offsetsCopy, sizesCopy); - assert(tiled.size() == 1 && "expected a single result from tiling"); - auto part = cast<TilingInterface>(tiled.front()); // Insert the results back and populate the `results` list. - for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) { + for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { SmallVector<OpFoldResult> resultOffsets, resultSizes; - if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, + if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, resultOffsets, resultSizes))) return nullptr; SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), b.getIndexAttr(1)); Value inserted = b.create<tensor::InsertSliceOp>( - loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, + loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); } - - return part; + // TODO: this part can be generalized maybe to not expect a single op. + assert(tilingResult->tiledOps.size() == 1 && + "expected split part to return a single tiled operation"); + return cast<TilingInterface>(tilingResult->tiledOps[0]); } std::pair<TilingInterface, TilingInterface> diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 62eef97a1744..1e404cabbb51 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl( } // 4. Tile the cloned op and delete the clone. - SmallVector<Operation *> tiledOps = + FailureOr<TilingResult> tilingResult = cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets, tiledSizes); b.eraseOp(clonedOp); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); } // 5. Parallel insert back into the result tensor. @@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( // 5. Tile the cloned op and delete the clone. if (tileSizes.empty()) { - SmallVector<Operation *> tiledOps = + FailureOr<TilingResult> tilingResult = cast<TilingInterface>(clonedOp).getTiledImplementation( b, tiledOffsets, tiledSizes); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); - tilingResults = tiledOp->getResults(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; } else { LinalgTilingOptions options; FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index cfc27ca44e42..676d6330cde3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -111,7 +111,7 @@ struct LinalgOpTilingInterface } // Instantiate the tiled implementation of the operation. - SmallVector<Operation *> + FailureOr<TilingResult> getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { @@ -129,7 +129,7 @@ struct LinalgOpTilingInterface Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast<LinalgOp>(tiledOp), offsets); - return {tiledOp}; + return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())}; } // Return the details of the output tile generated by the tiled @@ -160,10 +160,10 @@ struct LinalgOpTilingInterface return success(); } - FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes) const { + FailureOr<TilingResult> + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { auto linalgOp = cast<LinalgOp>(op); // Check that the indexing map used for the output is a projected @@ -197,12 +197,15 @@ struct LinalgOpTilingInterface iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; } - SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation( - b, iterationTileOffsets, iterationTileSizes); - if (tiledOp.size() != 1) + FailureOr<TilingResult> tilingResult = + tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, + iterationTileSizes); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); - return tiledOp[0]->getResult(resultNumber); + return TilingResult{ + tilingResult->tiledOps, + SmallVector<Value>{tilingResult->tiledValues[resultNumber]}}; } LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 17c46182eb5d..e001f59b21e9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( return failure(); } - Operation *tiledPadOp = + FailureOr<TilingResult> tilingResult = tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), zeroSliceGuard); + if (failed(tilingResult)) + return failure(); // All shapes are static and the data source is actually used. Rewrite into // pad(extract_slice(x)). - rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); + rewriter.replaceOp(sliceOp, tilingResult->tiledValues); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 915e4b4ed1c5..6706f5466283 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -251,18 +251,20 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder, /// a destination passing style op. static SmallVector<Value> yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, - Operation *tiledOp, + TilingResult tilingResult, ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, ArrayRef<SmallVector<OpFoldResult>> tileSizesList, MutableArrayRef<scf::ForOp> loops) { SmallVector<Value> replacements = - yieldTiledValues(rewriter, initValues, tiledOp->getResults(), + yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, tileOffsetsList, tileSizesList, loops); - if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { - auto innerMostLoop = loops.back(); - SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); - updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, - innerMostLoop.getRegionIterArgs()); + for (auto tiledOp : tilingResult.tiledOps) { + if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { + auto innerMostLoop = loops.back(); + SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, + innerMostLoop.getRegionIterArgs()); + } } return replacements; } @@ -345,9 +347,9 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, if (!tilingResult.loops.empty()) rewriter.setInsertionPoint( tilingResult.loops.back().getBody()->getTerminator()); - SmallVector<Operation *> tiledImplementation = + FailureOr<TilingResult> tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); - tilingResult.tiledOps.append(tiledImplementation); + tilingResult.tiledOps.append(tiledImplementation->tiledOps); if (op->getNumResults() == 0) { // nothing more to do. return tilingResult; @@ -356,9 +358,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, // If loops are empty, the tiled op is used as the replacement for the untiled // op. if (tilingResult.loops.empty()) { - tilingResult.replacements = llvm::to_vector( - llvm::map_range(tiledImplementation[0]->getResults(), - [](OpResult result) -> Value { return result; })); + tilingResult.replacements = tiledImplementation->tiledValues; return tilingResult; } @@ -384,7 +384,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, return rewriter.notifyMatchFailure(op, "failed to get destinations"); tilingResult.replacements = yieldTiledValues( - rewriter, destinationTensors, tilingResult.tiledOps.back(), + rewriter, destinationTensors, tiledImplementation.value(), resultOffsetsList, resultSizesList, tilingResult.loops); LLVM_DEBUG({ @@ -523,12 +523,13 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, // 2. Generate the tiled implementation of the producer of the source OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(candidateSliceOp); - FailureOr<Value> fusedProducerValue = + FailureOr<TilingResult> tileAndFuseResult = tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, fusableProducer); - if (failed(fusedProducerValue)) + if (failed(tileAndFuseResult)) return std::nullopt; - rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value()); + rewriter.replaceAllUsesWith(candidateSliceOp, + tileAndFuseResult->tiledValues[0]); // 3. If the slice is for a destination operand, for example, // @@ -592,8 +593,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, outerMostLoop.setIterArg(iterArgNumber.value(), dstOp.getTiedOpOperand(fusableProducer)->get()); } - if (auto dstOp = fusedProducerValue.value() - .getDefiningOp<DestinationStyleOpInterface>()) { + for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { + auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp); + if (!dstOp) + continue; scf::ForOp innerMostLoop = loops.back(); updateDestinationOperandsForTiledOp( rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), @@ -601,7 +604,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, } } return scf::SCFFuseProducerOfSliceResult{fusableProducer, - fusedProducerValue.value()}; + tileAndFuseResult->tiledValues[0]}; } /// Reconstruct the fused producer from within the tiled-and-fused code. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 1c4db01dc8f2..0faa29ade804 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -46,15 +46,15 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { return loopRanges; } - SmallVector<Operation *> + FailureOr<TilingResult> getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { - Operation *result = + FailureOr<TilingResult> result = tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes); - if (!result) - return {}; - return {result}; + if (failed(result)) + return failure(); + return result.value(); } LogicalResult @@ -117,7 +117,7 @@ struct PackOpTiling return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); } - SmallVector<Operation *> + FailureOr<TilingResult> getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { @@ -192,7 +192,8 @@ struct PackOpTiling Operation *tiledPackOp = b.create<PackOp>( loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); - return {tiledPackOp}; + return TilingResult{{tiledPackOp}, + SmallVector<Value>(tiledPackOp->getResults())}; } LogicalResult @@ -353,7 +354,7 @@ struct UnPackOpTiling /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we /// can get the actual result. - SmallVector<Operation *> + FailureOr<TilingResult> getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { @@ -412,12 +413,13 @@ struct UnPackOpTiling loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); if (isPerfectTilingCase) - return {tiledUnpackOp}; + return TilingResult{{tiledUnpackOp}, + SmallVector<Value>(tiledUnpackOp->getResults())}; - Operation *extractSlice = + auto extractSlice = b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - return {tiledUnpackOp, extractSlice}; + return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; } LogicalResult @@ -431,26 +433,29 @@ struct UnPackOpTiling return success(); } - FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes) const { - return getTiledImplementation(op, b, offsets, sizes) - .back() - ->getResult(resultNumber); + FailureOr<TilingResult> + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { + FailureOr<TilingResult> tilingResult = + getTiledImplementation(op, b, offsets, sizes); + if (failed(tilingResult)) + return failure(); + return tilingResult.value(); } }; } // namespace -Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, - ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, - bool generateZeroSliceGuard) { +FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, + tensor::PadOp padOp, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, + bool generateZeroSliceGuard) { // Only constant padding value supported. Value padValue = padOp.getConstantPaddingValue(); if (!padValue) - return nullptr; + return failure(); // Helper variables and functions for various arithmetic operations. These // are used extensively for computing new offset/length and padding values. @@ -584,10 +589,9 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, RankedTensorType::get(shape, padOp.getResultType().getElementType()); // Insert cast to ensure that types match. (May be folded away.) - auto castResult = [&](Operation *op) -> Operation * { - Value val = op->getResult(0); + auto castResult = [&](Value val) -> Value { if (resultType == val.getType()) - return op; + return val; return b.create<tensor::CastOp>(loc, resultType, val); }; @@ -601,7 +605,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, [&](OpBuilder &builder, Location gLoc, ValueRange indices) { builder.create<tensor::YieldOp>(gLoc, padValue); }); - return castResult(generateOp); + return generateOp; }; // Emit a SliceOp and a PadOp. Should not be used in cases where @@ -617,30 +621,38 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); // Cast result and return. - return castResult(newPadOp); + return newPadOp; }; // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that // the original data source x is not used. - if (hasZeroLen) - return createGenerateOp(); + if (hasZeroLen) { + Operation *generateOp = createGenerateOp(); + return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; + } // If there are dynamic dimensions: Generate an scf.if check to avoid // creating SliceOps with result dimensions of size 0 at runtime. if (generateZeroSliceGuard && dynHasZeroLenCond) { + Operation *thenOp; + Operation *elseOp; auto result = b.create<scf::IfOp>( loc, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0)); + thenOp = createGenerateOp(); + b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0))); }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0)); + elseOp = createPadOfExtractSlice(); + b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0))); }); - return result; + return TilingResult{{result}, SmallVector<Value>(result->getResults())}; } - return createPadOfExtractSlice(); + + Operation *newPadOp = createPadOfExtractSlice(); + return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; } void mlir::tensor::registerTilingInterfaceExternalModels( diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 65176ed7b9e7..40d79c205381 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -20,7 +20,7 @@ using namespace mlir; -FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer( +FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { auto producerOp = dyn_cast<TilingInterface>(producer.getOwner()); if (!producerOp) @@ -32,7 +32,7 @@ FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer( })) return failure(); - FailureOr<Value> tiledResult = producerOp.generateResultTileValue( + FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue( builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), sliceOp.getMixedSizes()); if (failed(tiledResult)) |