 mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index bbf0065b5329..f98efae76f86 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -46,7 +46,7 @@ using namespace mlir::sparse_tensor;
namespace {
/// Iteration graph sorting.
-enum SortMask {
+enum class SortMask : unsigned {
// The individual mask bits.
kIncludeDenseOutput = 0x1, // b001
kIncludeDenseInput = 0x2, // b010
@@ -57,17 +57,24 @@ enum SortMask {
kSparseOnly = 0x0, // b000
};
-/// SortMask tests on individual bits.
-inline static bool includeDenseInput(unsigned mask) {
- return mask & SortMask::kIncludeDenseInput;
+inline static bool includesAny(SortMask mask1, SortMask mask2) {
+ return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
}
-inline static bool includeDenseOutput(unsigned mask) {
- return mask & SortMask::kIncludeDenseOutput;
+inline static bool includesDenseInput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseInput);
}
-inline static bool includeUndef(unsigned mask) {
- return mask & SortMask::kIncludeUndef;
+inline static bool includesDenseOutput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseOutput);
+}
+
+inline static bool includesDense(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDense);
+}
+
+inline static bool includesUndef(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeUndef);
}
/// A helper class that visits an affine expression and tries to find an
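The core of the change above is the switch from a plain `enum` to a scoped `enum class`, which removes the implicit conversion to `unsigned` and funnels every bit test through the explicit casts in the new `includesAny` helper. A minimal standalone sketch of the same pattern (the names mirror the patch, but this is illustrative, not the actual MLIR sources):

```cpp
#include <cassert>

// Scoped enumeration: enumerators are namespaced (SortMask::...) and do
// not convert implicitly to integers, unlike a plain `enum`.
enum class SortMask : unsigned {
  kIncludeDenseOutput = 0x1, // b001
  kIncludeDenseInput = 0x2,  // b010
  kIncludeUndef = 0x4,       // b100
  kIncludeAll = 0x7,         // b111
  kIncludeDense = 0x3,       // b011
  kSparseOnly = 0x0,         // b000
};

// The one place that escapes the type system: cast both masks back to the
// underlying unsigned type and test for any overlapping bit.
inline static bool includesAny(SortMask mask1, SortMask mask2) {
  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
}

int main() {
  assert(includesAny(SortMask::kIncludeAll, SortMask::kIncludeUndef));
  assert(!includesAny(SortMask::kSparseOnly, SortMask::kIncludeDense));
  // A plain enum would also have accepted `mask & 0x4u` silently;
  // with the enum class, that no longer compiles.
  return 0;
}
```

Centralizing the casts in `includesAny` keeps the per-bit predicates (`includesDenseInput` and friends) free of raw integer manipulation.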
@@ -456,7 +463,7 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
/// along fixed levels. Even for dense storage formats, however, the natural
/// coordinate order yields innermost unit-stride access with better spatial
/// locality.
-static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
+static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
OpOperand *skip = nullptr) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
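`computeIterationGraph`, whose signature changes here, builds an n x n boolean from/to matrix where entry (i, j) records that loop i must run before loop j, then looks for a topological order of the loops. A hedged sketch of that idea, assuming the simplest in-degree-based sort (the real implementation differs in its details):

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of the iteration graph: adj[i][j] == true
// means loop i must appear before loop j in the final loop order.
using Graph = std::vector<std::vector<bool>>;

// Topologically sort the n loops, or report failure when the constraints
// contain a cycle.
static bool topoSort(const Graph &adj, std::vector<unsigned> &order) {
  const unsigned n = adj.size();
  std::vector<unsigned> inDegree(n, 0);
  for (unsigned i = 0; i < n; i++)
    for (unsigned j = 0; j < n; j++)
      if (adj[i][j])
        inDegree[j]++;
  // Repeatedly emit a loop with no remaining predecessors.
  std::vector<bool> done(n, false);
  for (unsigned k = 0; k < n; k++) {
    unsigned next = n;
    for (unsigned i = 0; i < n; i++)
      if (!done[i] && inDegree[i] == 0) {
        next = i;
        break;
      }
    if (next == n)
      return false; // cycle: no valid loop order under these constraints
    done[next] = true;
    order.push_back(next);
    for (unsigned j = 0; j < n; j++)
      if (adj[next][j])
        inDegree[j]--;
  }
  return true;
}

int main() {
  // Constraints i0 < i1 and i0 < i2 admit the order i0, i1, i2.
  Graph adj = {{false, true, true},
               {false, false, false},
               {false, false, false}};
  std::vector<unsigned> order;
  assert(topoSort(adj, order) && order[0] == 0);
  return 0;
}
```

The `SortMask` parameter controls which tensors contribute edges to this matrix: weaker masks add fewer constraints, making an acyclic graph (and hence a valid loop order) more likely.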
@@ -471,18 +478,17 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
const auto enc = getSparseTensorEncoding(t.get().getType());
assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
- bool isDenseInput = !enc && env.op().isDpsInput(&t);
- bool isDenseOutput = !enc && !isDenseInput;
-
// Skips dense inputs/outputs when not requested.
- if ((isDenseInput && !includeDenseInput(mask)) ||
- (isDenseOutput && !includeDenseOutput(mask)))
+ const bool isDenseInput = !enc && env.op().isDpsInput(&t);
+ const bool isDenseOutput = !enc && !isDenseInput;
+ if ((isDenseInput && !includesDenseInput(mask)) ||
+ (isDenseOutput && !includesDenseOutput(mask)))
continue;
// Push unrelated loops into sparse iteration space, so these
// will be skipped more often.
// TODO: Do we really need this?
- if (includeUndef(mask)) {
+ if (includesUndef(mask)) {
unsigned tensor = t.getOperandNumber();
for (unsigned i = 0; i < n; i++) {
if (isCompressedDLT(env.dlt(tensor, i)) ||
@@ -540,7 +546,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
// E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
// It is totally fine to have loop sequence d0->d2->d1->d3 instead of
// requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
- if (!(mask & SortMask::kIncludeDense))
+ if (!includesDense(mask))
tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
// (d0 + d1) < (d2 + d3), or
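With the stronger typing, call sites now spell out `SortMask` values explicitly. Below is a self-contained sketch of how a caller might probe progressively weaker masks until some loop order works; the stubbed `computeIterationGraph` and the particular fallback order are assumptions for illustration, not the file's actual driver code:

```cpp
#include <cstdio>

enum class SortMask : unsigned {
  kIncludeDenseOutput = 0x1,
  kIncludeDenseInput = 0x2,
  kIncludeUndef = 0x4,
  kIncludeAll = 0x7,
  kIncludeDense = 0x3,
  kSparseOnly = 0x0,
};

// Stand-in for the real computeIterationGraph(CodegenEnv &, SortMask, ...);
// here it pretends only the sparse-only constraint set admits a loop order.
static bool computeIterationGraph(SortMask mask) {
  return mask == SortMask::kSparseOnly;
}

int main() {
  // Try the strictest constraint set first, then fall back to weaker masks
  // until a topological order exists. (Illustrative order only.)
  const SortMask tries[] = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                            SortMask::kIncludeUndef, SortMask::kSparseOnly};
  for (SortMask mask : tries)
    if (computeIterationGraph(mask)) {
      std::printf("found loop order with mask %u\n",
                  static_cast<unsigned>(mask));
      break;
    }
  return 0;
}
```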