diff options
author | Peiming Liu <peiming@google.com> | 2022-12-08 20:03:18 +0000 |
---|---|---|
committer | Peiming Liu <peiming@google.com> | 2022-12-08 23:36:30 +0000 |
commit | faa75f94f113f2e912a672afcd611ecc79942e6a (patch) | |
tree | ae5a67711b67d2e249a8b3dee3f089e7151e104b /mlir/lib/Dialect | |
parent | 4efcea95852abe6ed25ae9a2bf8c3a51a1157675 (diff) | |
download | llvm-faa75f94f113f2e912a672afcd611ecc79942e6a.tar.gz |
[mlir][sparse] reject kernels with non-sparsfiable reduction expression.
To address https://github.com/llvm/llvm-project/issues/59394.
Reduction on negation of the output tensor is a non-sparsifiable kernel, it creates cyclic dependency.
This patch reject those cases instead of crashing.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D139659
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 13 | ||||
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 126 |
2 files changed, 139 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 81f3845eebec..8fbbf6aed6b6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -583,6 +583,19 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op, std::vector<unsigned> &topSort, unsigned exp, OpOperand **sparseOut, unsigned &outerParNest) { + // We reject any expression that makes a reduction from `-outTensor`, as those + // expression create dependency between the current iteration (i) and the + // previous iteration (i-1). It would then require iterating over the whole + // coordinate space, which prevent us from exploiting sparsity for faster + // code. + for (utils::IteratorType it : op.getIteratorTypesArray()) { + if (it == utils::IteratorType::reduction) { + if (merger.hasNegateOnOut(exp)) + return false; + break; + } + } + OpOperand *lhs = op.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); auto enc = getSparseTensorEncoding(lhs->get().getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 35530bf4ed3b..bf6612021f92 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -18,6 +18,81 @@ namespace mlir { namespace sparse_tensor { +enum class ExpArity { + kNullary, + kUnary, + kBinary, +}; + +static ExpArity getExpArity(Kind k) { + switch (k) { + // Leaf. + case kTensor: + case kInvariant: + case kIndex: + return ExpArity::kNullary; + case kAbsF: + case kAbsC: + case kAbsI: + case kCeilF: + case kFloorF: + case kSqrtF: + case kSqrtC: + case kExpm1F: + case kExpm1C: + case kLog1pF: + case kLog1pC: + case kSinF: + case kSinC: + case kTanhF: + case kTanhC: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kCastIdx: + case kTruncI: + case kCIm: + case kCRe: + case kBitCast: + case kBinaryBranch: + case kUnary: + case kSelect: + case kNegF: + case kNegC: + case kNegI: + return ExpArity::kUnary; + // Binary operations. + case kDivF: + case kDivC: + case kDivS: + case kDivU: + case kShrS: + case kShrU: + case kShlI: + case kMulF: + case kMulC: + case kMulI: + case kAndI: + case kAddF: + case kAddC: + case kAddI: + case kOrI: + case kXorI: + case kBinary: + case kReduce: + case kSubF: + case kSubC: + case kSubI: + return ExpArity::kBinary; + } + llvm_unreachable("unexpected kind"); +} + //===----------------------------------------------------------------------===// // Constructors. //===----------------------------------------------------------------------===// @@ -310,6 +385,57 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) { return !hasAnySparse(tmp); } +bool Merger::expContainsTensor(unsigned e, unsigned t) const { + if (tensorExps[e].kind == kTensor) + return tensorExps[e].tensor == t; + + switch (getExpArity(tensorExps[e].kind)) { + case ExpArity::kNullary: + return false; + case ExpArity::kUnary: { + unsigned op = tensorExps[e].children.e0; + if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t) + return true; + return expContainsTensor(op, t); + } + case ExpArity::kBinary: { + unsigned op1 = tensorExps[e].children.e0; + unsigned op2 = tensorExps[e].children.e1; + if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) || + (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t)) + return true; + return expContainsTensor(op1, t) || expContainsTensor(op2, t); + } + } + llvm_unreachable("unexpected arity"); +} + +bool Merger::hasNegateOnOut(unsigned e) const { + switch (tensorExps[e].kind) { + case kNegF: + case kNegC: + case kNegI: + return expContainsTensor(tensorExps[e].children.e0, outTensor); + case kSubF: + case kSubC: + case kSubI: + return expContainsTensor(tensorExps[e].children.e1, outTensor) || + hasNegateOnOut(tensorExps[e].children.e0); + default: { + switch (getExpArity(tensorExps[e].kind)) { + case ExpArity::kNullary: + return false; + case ExpArity::kUnary: + return hasNegateOnOut(tensorExps[e].children.e0); + case ExpArity::kBinary: + return hasNegateOnOut(tensorExps[e].children.e0) || + hasNegateOnOut(tensorExps[e].children.e1); + } + } + } + llvm_unreachable("unexpected kind"); +} + bool Merger::isSingleCondition(unsigned t, unsigned e) const { switch (tensorExps[e].kind) { // Leaf. |