summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorPeiming Liu <peiming@google.com>2023-03-09 20:46:04 +0000
committerPeiming Liu <peiming@google.com>2023-03-09 20:55:03 +0000
commit41089f86e37b213ff9e8e204346fa88fb217404b (patch)
treeb0e122827cd42faff4ba4c2bb8cbdb8abf550856 /mlir/lib
parent4c82050c56926d840e4ccf253ad10e6ae3ee6cc7 (diff)
downloadllvm-41089f86e37b213ff9e8e204346fa88fb217404b.tar.gz
[mlir][sparse] fix bugs when convert coo to coo but with different dim ordering
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D145723
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp12
1 files changed, 9 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cb757ef07889..d5e604d05c00 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -681,14 +681,21 @@ private:
// COO tensor.
// TODO: enhance foreachOp to take ordering to remove the need of a
// temporary COO tensor here.
- const RankedTensorType bufferTp = dstTp.isIdentity()
+ const RankedTensorType bufferTp = dstTp.isIdentity() || fromSparseConst
? dstTp.getRankedTensorType()
: getUnorderedCOOFromTypeWithOrdering(
dstTp, dstTp.getDimToLvlMap());
+ // Only imposes foreach order on dense constant (which will be statically
+ // sorted by the sparse compiler), otherwise the rotated loop sequence
+ // results in bad cache locality.
+ AffineMapAttr foreachOrder = nullptr;
+ if (encDst.getDimOrdering() && fromSparseConst)
+ foreachOrder = AffineMapAttr::get(encDst.getDimOrdering());
+
auto buffer =
rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, buffer,
+ loc, src, buffer, foreachOrder,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
Value input = reduc.front();
@@ -795,7 +802,6 @@ private:
// tensor (e.g., src tensor is not ordered or src tensor haves a different
// dimOrdering).
if (const SparseTensorType srcTp(srcRTT);
- !isUniqueCOOType(srcRTT) &&
!(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
// Construct a COO tensor from the src tensor.
// TODO: there may be cases for which more efficiently without