diff options
author | Anlun Xu <anlunx@google.com> | 2023-04-30 18:16:11 -0700 |
---|---|---|
committer | Anlun Xu <anlunx@google.com> | 2023-05-16 14:56:33 -0700 |
commit | 6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea (patch) | |
tree | e6aa6557a17e6ee132927bf39998029ba9f71fe0 /mlir/lib | |
parent | 62a2feff5784bcee3c7b037501956552acdf736c (diff) | |
download | llvm-6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea.tar.gz |
[mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D149564
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 103 |
1 files changed, 102 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index ca27794b64c1..a16ab660e931 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -386,6 +386,106 @@ public: }; /// Sparse rewriting rule for sparse-to-sparse reshape operator. +struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> { +public: + using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ReshapeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value srcTensor = op.getSource(); + const auto srcTp = getSparseTensorType(srcTensor); + const auto dstTp = getSparseTensorType(op.getResult()); + + if (!srcTp.hasEncoding() || !dstTp.hasEncoding() || + !dstTp.hasStaticDimShape()) + return failure(); + + SmallVector<Value> srcSizes; + sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor); + SmallVector<Value> dstSizes; + for (Dimension d : dstTp.getDimShape()) + dstSizes.push_back(constantIndex(rewriter, loc, d)); + + Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor); + // Only need an unordered COO buffer if input and output are not sorted + // in the same way. + Type bufferTp = + srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity() + ? dstTp.getRankedTensorType() + : getUnorderedCOOFromType(dstTp); + SmallVector<Value> dynSizes; + Value buffer = rewriter + .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(), + nnz, Attribute()) + .getResult(); + + // Convert src coordinates to dst coordinates by first collapsing it to 1D + // and then expand it to the match the rank of the destination tensor. + // Implemented as follows: + // foreach srcCoords %srcTensor + // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank]) + // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank]) + // insert expandedCoords, %buffer + // + // followed by an optional + // %t = sparse_tensor.cast %tmp + // depending on whether the input/output are sorted in the same way. + const auto encSrc = srcTp.getEncoding(); + ForeachOp foreachOp = rewriter.create<ForeachOp>( + loc, srcTensor, buffer, + [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, + ValueRange reduc) { + const Dimension srcRank = srcTp.getDimRank(); + SmallVector<Value> srcDcvs; + srcDcvs.reserve(srcRank); + for (Dimension d = 0; d < srcRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level lvl = toStoredDim(encSrc, d); + srcDcvs.push_back(srcLcvs[lvl]); + } + + Value collapsed_size = constantIndex(builder, loc, 1); + for (Dimension d = 0; d < srcRank; d++) + collapsed_size = + builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]); + SmallVector<Value, 1> collapsedSizes = {collapsed_size}; + + ReassociationIndices collapse_indices; + for (Dimension i = 0; i < srcRank; i++) + collapse_indices.push_back(i); + SmallVector<ReassociationIndices, 1> collapse_reassociation = { + collapse_indices}; + SmallVector<Value, 1> collapsedDcvs; + reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs, + collapsedSizes, collapsedDcvs); + + ReassociationIndices expand_indices; + for (Dimension i = 0; i < dstTp.getDimRank(); i++) + expand_indices.push_back(i); + SmallVector<ReassociationIndices, 1> expand_reassociation = { + expand_indices}; + SmallVector<Value> dstDcvs; + reshapeCvs(builder, loc, expand_reassociation, collapsedSizes, + collapsedDcvs, dstSizes, dstDcvs); + + auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs); + builder.create<sparse_tensor::YieldOp>(loc, t); + }); + + Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true); + if (bufferTp != dstTp) { + auto dstRTT = dstTp.getRankedTensorType(); + Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult(); + rewriter.create<DeallocTensorOp>(loc, t); + t = converted; + } + rewriter.replaceOp(op, t); + return success(); + } +}; + +/// Sparse rewriting rule for sparse-to-sparse reshape operator. template <typename ReshapeOp> struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> { public: @@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns, bool enableForeach, bool enableConvert) { patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>, - ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext()); + ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>( + patterns.getContext()); if (enableForeach) patterns.add<ForeachRewriter>(patterns.getContext()); // TODO: If RT not enabled, rewrite concatenate ops, etc here. |