author     Anlun Xu <anlunx@google.com>  2023-04-30 18:16:11 -0700
committer  Anlun Xu <anlunx@google.com>  2023-05-16 14:56:33 -0700
commit     6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea (patch)
tree       e6aa6557a17e6ee132927bf39998029ba9f71fe0 /mlir/lib
parent     62a2feff5784bcee3c7b037501956552acdf736c (diff)
[mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149564
Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp  103
1 file changed, 102 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ca27794b64c1..a16ab660e931 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -386,6 +386,106 @@ public:
};
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+public:
+  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSource();
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+        !dstTp.hasStaticDimShape())
+      return failure();
+
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    for (Dimension d : dstTp.getDimShape())
+      dstSizes.push_back(constantIndex(rewriter, loc, d));
+
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+    // Only need an unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
+            : getUnorderedCOOFromType(dstTp);
+    SmallVector<Value> dynSizes;
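+    // The destination shape is fully static (checked in the guard above), so
+    // no dynamic sizes are needed; the source's entry count only serves as a
+    // size hint for the allocation.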
+    Value buffer = rewriter
+                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+                                              nnz, Attribute())
+                       .getResult();
+
+    // Convert src coordinates to dst coordinates by first collapsing them to
+    // 1D and then expanding them to match the rank of the destination tensor.
+    // Implemented as follows:
+    //   foreach srcCoords %srcTensor
+    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
+    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
+    //     insert expandedCoords, %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
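+    //
+    // For example, reshaping a 2x3 source into a 3x2 destination maps the
+    // source coordinate (1, 2) to the linearized position 1*3 + 2 = 5, which
+    // expands to the destination coordinate (5/2, 5%2) = (2, 1).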
+    const auto encSrc = srcTp.getEncoding();
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, buffer,
+        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+            ValueRange reduc) {
+          const Dimension srcRank = srcTp.getDimRank();
+          SmallVector<Value> srcDcvs;
+          srcDcvs.reserve(srcRank);
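+          // The foreach body receives coordinates in level order; map each
+          // dimension back to the level that stores it to recover the
+          // dimension coordinates.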
+          for (Dimension d = 0; d < srcRank; d++) {
+            // FIXME: `toStoredDim` is deprecated
+            Level lvl = toStoredDim(encSrc, d);
+            srcDcvs.push_back(srcLcvs[lvl]);
+          }
+
+          Value collapsedSize = constantIndex(builder, loc, 1);
+          for (Dimension d = 0; d < srcRank; d++)
+            collapsedSize =
+                builder.create<arith::MulIOp>(loc, collapsedSize, srcSizes[d]);
+          SmallVector<Value, 1> collapsedSizes = {collapsedSize};
+
+          ReassociationIndices collapseIndices;
+          for (Dimension i = 0; i < srcRank; i++)
+            collapseIndices.push_back(i);
+          SmallVector<ReassociationIndices, 1> collapseReassociation = {
+              collapseIndices};
+          SmallVector<Value, 1> collapsedDcvs;
+          reshapeCvs(builder, loc, collapseReassociation, srcSizes, srcDcvs,
+                     collapsedSizes, collapsedDcvs);
+
+          ReassociationIndices expandIndices;
+          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+            expandIndices.push_back(i);
+          SmallVector<ReassociationIndices, 1> expandReassociation = {
+              expandIndices};
+          SmallVector<Value> dstDcvs;
+          reshapeCvs(builder, loc, expandReassociation, collapsedSizes,
+                     collapsedDcvs, dstSizes, dstDcvs);
+
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
+
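+    // Finalize the pending insertions into the buffer (hasInserts == true).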
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
+    return success();
+  }
+};
+
+/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                               bool enableForeach,
                                               bool enableConvert) {
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-              ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+              ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+      patterns.getContext());
  if (enableForeach)
    patterns.add<ForeachRewriter>(patterns.getContext());
  // TODO: If RT not enabled, rewrite concatenate ops, etc here.
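
For context, a minimal sketch of how a pass body might pick up the new pattern through the populate function changed above. This is not part of the commit: the driver function is hypothetical, the flag values are illustrative, and the enableRT parameter is inferred from the surrounding signature and the TODO above.

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: applies the post-sparsification rewrites, which now
// include TensorReshapeRewriter for tensor::ReshapeOp, to a module.
static mlir::LogicalResult
runPostSparsificationRewriting(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  // Flag values are assumptions for illustration only.
  mlir::populatePostSparsificationRewriting(patterns, /*enableRT=*/false,
                                            /*enableForeach=*/true,
                                            /*enableConvert=*/true);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}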