author    Peiming Liu <peiming@google.com>  2022-12-08 20:03:18 +0000
committer Peiming Liu <peiming@google.com>  2022-12-08 23:36:30 +0000
commit    faa75f94f113f2e912a672afcd611ecc79942e6a (patch)
tree      ae5a67711b67d2e249a8b3dee3f089e7151e104b /mlir/lib/Dialect
parent    4efcea95852abe6ed25ae9a2bf8c3a51a1157675 (diff)
[mlir][sparse] reject kernels with non-sparsifiable reduction expression.
To address https://github.com/llvm/llvm-project/issues/59394. A reduction over a negation of the output tensor is a non-sparsifiable kernel: it creates a cyclic dependency. This patch rejects those cases instead of crashing.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139659
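For illustration only (a hypothetical kernel sketched here, not necessarily the exact one from the issue), the rejected pattern has the shape

    x(i) = SUM_j a(i,j) * (-x(i))

where the reduction body reads a negation of the very tensor it is accumulating into, so each partial update depends on the previously accumulated value and the generated loops could no longer skip the zero entries of the sparse operands.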
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--   mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp   13
-rw-r--r--   mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp                 126
2 files changed, 139 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 81f3845eebec..8fbbf6aed6b6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -583,6 +583,19 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
std::vector<unsigned> &topSort, unsigned exp,
OpOperand **sparseOut,
unsigned &outerParNest) {
+ // We reject any expression that makes a reduction from `-outTensor`, as such
+ // expressions create a dependency between the current iteration (i) and the
+ // previous iteration (i-1). It would then require iterating over the whole
+ // coordinate space, which prevents us from exploiting sparsity for faster
+ // code.
+ for (utils::IteratorType it : op.getIteratorTypesArray()) {
+ if (it == utils::IteratorType::reduction) {
+ if (merger.hasNegateOnOut(exp))
+ return false;
+ break;
+ }
+ }
+
OpOperand *lhs = op.getDpsInitOperand(0);
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 35530bf4ed3b..bf6612021f92 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -18,6 +18,81 @@
namespace mlir {
namespace sparse_tensor {
+enum class ExpArity {
+ kNullary,
+ kUnary,
+ kBinary,
+};
+
+static ExpArity getExpArity(Kind k) {
+ switch (k) {
+ // Leaf.
+ case kTensor:
+ case kInvariant:
+ case kIndex:
+ return ExpArity::kNullary;
+ case kAbsF:
+ case kAbsC:
+ case kAbsI:
+ case kCeilF:
+ case kFloorF:
+ case kSqrtF:
+ case kSqrtC:
+ case kExpm1F:
+ case kExpm1C:
+ case kLog1pF:
+ case kLog1pC:
+ case kSinF:
+ case kSinC:
+ case kTanhF:
+ case kTanhC:
+ case kTruncF:
+ case kExtF:
+ case kCastFS:
+ case kCastFU:
+ case kCastSF:
+ case kCastUF:
+ case kCastS:
+ case kCastU:
+ case kCastIdx:
+ case kTruncI:
+ case kCIm:
+ case kCRe:
+ case kBitCast:
+ case kBinaryBranch:
+ case kUnary:
+ case kSelect:
+ case kNegF:
+ case kNegC:
+ case kNegI:
+ return ExpArity::kUnary;
+ // Binary operations.
+ case kDivF:
+ case kDivC:
+ case kDivS:
+ case kDivU:
+ case kShrS:
+ case kShrU:
+ case kShlI:
+ case kMulF:
+ case kMulC:
+ case kMulI:
+ case kAndI:
+ case kAddF:
+ case kAddC:
+ case kAddI:
+ case kOrI:
+ case kXorI:
+ case kBinary:
+ case kReduce:
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ return ExpArity::kBinary;
+ }
+ llvm_unreachable("unexpected kind");
+}
+
//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//
@@ -310,6 +385,57 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
return !hasAnySparse(tmp);
}
+bool Merger::expContainsTensor(unsigned e, unsigned t) const {
+ if (tensorExps[e].kind == kTensor)
+ return tensorExps[e].tensor == t;
+
+ switch (getExpArity(tensorExps[e].kind)) {
+ case ExpArity::kNullary:
+ return false;
+ case ExpArity::kUnary: {
+ unsigned op = tensorExps[e].children.e0;
+ if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t)
+ return true;
+ return expContainsTensor(op, t);
+ }
+ case ExpArity::kBinary: {
+ unsigned op1 = tensorExps[e].children.e0;
+ unsigned op2 = tensorExps[e].children.e1;
+ if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) ||
+ (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t))
+ return true;
+ return expContainsTensor(op1, t) || expContainsTensor(op2, t);
+ }
+ }
+ llvm_unreachable("unexpected arity");
+}
+
+bool Merger::hasNegateOnOut(unsigned e) const {
+ switch (tensorExps[e].kind) {
+ case kNegF:
+ case kNegC:
+ case kNegI:
+ return expContainsTensor(tensorExps[e].children.e0, outTensor);
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
+ hasNegateOnOut(tensorExps[e].children.e0);
+ default: {
+ switch (getExpArity(tensorExps[e].kind)) {
+ case ExpArity::kNullary:
+ return false;
+ case ExpArity::kUnary:
+ return hasNegateOnOut(tensorExps[e].children.e0);
+ case ExpArity::kBinary:
+ return hasNegateOnOut(tensorExps[e].children.e0) ||
+ hasNegateOnOut(tensorExps[e].children.e1);
+ }
+ }
+ }
+ llvm_unreachable("unexpected kind");
+}
+
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
switch (tensorExps[e].kind) {
// Leaf.