diff options
author | Peiming Liu <peiming@google.com> | 2023-03-09 20:46:04 +0000 |
---|---|---|
committer | Peiming Liu <peiming@google.com> | 2023-03-09 20:55:03 +0000 |
commit | 41089f86e37b213ff9e8e204346fa88fb217404b (patch) | |
tree | b0e122827cd42faff4ba4c2bb8cbdb8abf550856 /mlir/lib | |
parent | 4c82050c56926d840e4ccf253ad10e6ae3ee6cc7 (diff) | |
download | llvm-41089f86e37b213ff9e8e204346fa88fb217404b.tar.gz |
[mlir][sparse] fix bugs when converting COO to COO with a different dim ordering
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145723
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
```diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cb757ef07889..d5e604d05c00 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -681,14 +681,21 @@ private:
       // COO tensor.
       // TODO: enhance foreachOp to take ordering to remove the need of a
       // temporary COO tensor here.
-      const RankedTensorType bufferTp = dstTp.isIdentity()
+      const RankedTensorType bufferTp = dstTp.isIdentity() || fromSparseConst
                                             ? dstTp.getRankedTensorType()
                                             : getUnorderedCOOFromTypeWithOrdering(
                                                   dstTp, dstTp.getDimToLvlMap());
+      // Only imposes foreach order on dense constant (which will be statically
+      // sorted by the sparse compiler), otherwise the rotated loop sequence
+      // results to bad cache locality.
+      AffineMapAttr foreachOrder = nullptr;
+      if (encDst.getDimOrdering() && fromSparseConst)
+        foreachOrder = AffineMapAttr::get(encDst.getDimOrdering());
+
       auto buffer =
           rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
       auto foreachOp = rewriter.create<ForeachOp>(
-          loc, src, buffer,
+          loc, src, buffer, foreachOrder,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             Value input = reduc.front();
@@ -795,7 +802,6 @@ private:
     // tensor (e.g., src tensor is not ordered or src tensor haves a different
     // dimOrdering).
     if (const SparseTensorType srcTp(srcRTT);
-        !isUniqueCOOType(srcRTT) &&
         !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
       // Construct a COO tensor from the src tensor.
       // TODO: there may be cases for which more efficiently without
```