summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAnlun Xu <anlunx@google.com>2023-04-30 18:16:11 -0700
committerAnlun Xu <anlunx@google.com>2023-05-16 14:56:33 -0700
commit6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea (patch)
treee6aa6557a17e6ee132927bf39998029ba9f71fe0 /mlir
parent62a2feff5784bcee3c7b037501956552acdf736c (diff)
downloadllvm-6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea.tar.gz
[mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D149564
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp103
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir45
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir87
3 files changed, 234 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ca27794b64c1..a16ab660e931 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -386,6 +386,106 @@ public:
};
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+public:
+ using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value srcTensor = op.getSource();
+ const auto srcTp = getSparseTensorType(srcTensor);
+ const auto dstTp = getSparseTensorType(op.getResult());
+
+ if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+ !dstTp.hasStaticDimShape())
+ return failure();
+
+ SmallVector<Value> srcSizes;
+ sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+ SmallVector<Value> dstSizes;
+ for (Dimension d : dstTp.getDimShape())
+ dstSizes.push_back(constantIndex(rewriter, loc, d));
+
+ Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+ // Only need an unordered COO buffer if input and output are not sorted
+ // in the same way.
+ Type bufferTp =
+ srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+ ? dstTp.getRankedTensorType()
+ : getUnorderedCOOFromType(dstTp);
+ SmallVector<Value> dynSizes;
+ Value buffer = rewriter
+ .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+ nnz, Attribute())
+ .getResult();
+
+ // Convert src coordinates to dst coordinates by first collapsing it to 1D
+ // and then expand it to the match the rank of the destination tensor.
+ // Implemented as follows:
+ // foreach srcCoords %srcTensor
+ // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
+ // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
+ // insert expandedCoords, %buffer
+ //
+ // followed by an optional
+ // %t = sparse_tensor.cast %tmp
+ // depending on whether the input/output are sorted in the same way.
+ const auto encSrc = srcTp.getEncoding();
+ ForeachOp foreachOp = rewriter.create<ForeachOp>(
+ loc, srcTensor, buffer,
+ [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+ ValueRange reduc) {
+ const Dimension srcRank = srcTp.getDimRank();
+ SmallVector<Value> srcDcvs;
+ srcDcvs.reserve(srcRank);
+ for (Dimension d = 0; d < srcRank; d++) {
+ // FIXME: `toStoredDim` is deprecated
+ Level lvl = toStoredDim(encSrc, d);
+ srcDcvs.push_back(srcLcvs[lvl]);
+ }
+
+ Value collapsed_size = constantIndex(builder, loc, 1);
+ for (Dimension d = 0; d < srcRank; d++)
+ collapsed_size =
+ builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]);
+ SmallVector<Value, 1> collapsedSizes = {collapsed_size};
+
+ ReassociationIndices collapse_indices;
+ for (Dimension i = 0; i < srcRank; i++)
+ collapse_indices.push_back(i);
+ SmallVector<ReassociationIndices, 1> collapse_reassociation = {
+ collapse_indices};
+ SmallVector<Value, 1> collapsedDcvs;
+ reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs,
+ collapsedSizes, collapsedDcvs);
+
+ ReassociationIndices expand_indices;
+ for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+ expand_indices.push_back(i);
+ SmallVector<ReassociationIndices, 1> expand_reassociation = {
+ expand_indices};
+ SmallVector<Value> dstDcvs;
+ reshapeCvs(builder, loc, expand_reassociation, collapsedSizes,
+ collapsedDcvs, dstSizes, dstDcvs);
+
+ auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+ builder.create<sparse_tensor::YieldOp>(loc, t);
+ });
+
+ Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+ if (bufferTp != dstTp) {
+ auto dstRTT = dstTp.getRankedTensorType();
+ Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+ rewriter.create<DeallocTensorOp>(loc, t);
+ t = converted;
+ }
+ rewriter.replaceOp(op, t);
+ return success();
+ }
+};
+
+/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableForeach,
bool enableConvert) {
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+ ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+ patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
new file mode 100644
index 000000000000..369044c38f76
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --cse --canonicalize | FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// CHECK: func.func @sparse_reshape(
+// CHECK-SAME: %[[S:.*]]:
+// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK: %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK: %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK: %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK: %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
+// CHECK: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK: %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK: %[[T:.*]] = arith.muli %[[SI0]], %[[C25]] : index
+// CHECK: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK: %[[D:.*]] = arith.divui %[[DI]], %[[C10]] : index
+// CHECK: %[[R:.*]] = arith.remui %[[DI]], %[[C10]] : index
+// CHECK: %[[R1:.*]] = sparse_tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[D]], %[[R]]]
+// CHECK: scf.yield %[[R1]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_1]]
+// CHECK: }
+// CHECK: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK: return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
+func.func @sparse_reshape(%arg0: tensor<4x25xf64, #SparseMatrix>) -> tensor<10x10xf64, #SparseMatrix> {
+ %shape = arith.constant dense <[ 10, 10 ]> : tensor<2xi32>
+ %0 = tensor.reshape %arg0(%shape) :
+ (tensor<4x25xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<10x10xf64, #SparseMatrix>
+ return %0 : tensor<10x10xf64, #SparseMatrix>
+} \ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
new file mode 100644
index 000000000000..4945294f727e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -0,0 +1,87 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+
+#SparseVector = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed", "compressed"]
+}>
+
+#Sparse3dTensor = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed", "compressed", "compressed"]
+}>
+
+module {
+
+ func.func @reshape0(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> {
+ %shape = arith.constant dense <[ 2, 6 ]> : tensor<2xi32>
+ %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<2x6xf64, #SparseMatrix>
+ return %0 : tensor<2x6xf64, #SparseMatrix>
+ }
+
+ func.func @reshape1(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+ %shape = arith.constant dense <[ 12 ]> : tensor<1xi32>
+ %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<1xi32>) -> tensor<12xf64, #SparseVector>
+ return %0 : tensor<12xf64, #SparseVector>
+ }
+
+ func.func @reshape2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> {
+ %shape = arith.constant dense <[ 2, 3, 2 ]> : tensor<3xi32>
+ %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<3xi32>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+ return %0 : tensor<2x3x2xf64, #Sparse3dTensor>
+ }
+
+
+ func.func @entry() {
+ %m = arith.constant dense <[ [ 1.1, 0.0, 1.3, 0.0 ],
+ [ 2.1, 0.0, 2.3, 0.0 ],
+ [ 3.1, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
+ %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix>
+
+ %reshaped0 = call @reshape0(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix>
+ %reshaped1 = call @reshape1(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
+ %reshaped2 = call @reshape2(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+
+ %c0 = arith.constant 0 : index
+ %df = arith.constant -1.0 : f64
+
+ // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+ %b0 = sparse_tensor.values %reshaped0: tensor<2x6xf64, #SparseMatrix> to memref<?xf64>
+ %v0 = vector.transfer_read %b0[%c0], %df: memref<?xf64>, vector<12xf64>
+ vector.print %v0 : vector<12xf64>
+
+ // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+ %b1 = sparse_tensor.values %reshaped1: tensor<12xf64, #SparseVector> to memref<?xf64>
+ %v1 = vector.transfer_read %b1[%c0], %df: memref<?xf64>, vector<12xf64>
+ vector.print %v1 : vector<12xf64>
+
+ // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+ %b2 = sparse_tensor.values %reshaped2: tensor<2x3x2xf64, #Sparse3dTensor> to memref<?xf64>
+ %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<12xf64>
+ vector.print %v2: vector<12xf64>
+
+ bufferization.dealloc_tensor %sm : tensor<3x4xf64, #SparseMatrix>
+ bufferization.dealloc_tensor %reshaped0 : tensor<2x6xf64, #SparseMatrix>
+ bufferization.dealloc_tensor %reshaped1 : tensor<12xf64, #SparseVector>
+ bufferization.dealloc_tensor %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor>
+
+ return
+ }
+
+} \ No newline at end of file