diff options
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index bbf0065b5329..f98efae76f86 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -46,7 +46,7 @@ using namespace mlir::sparse_tensor; namespace { /// Iteration graph sorting. -enum SortMask { +enum class SortMask : unsigned { // The individual mask bits. kIncludeDenseOutput = 0x1, // b001 kIncludeDenseInput = 0x2, // b010 @@ -57,17 +57,24 @@ enum SortMask { kSparseOnly = 0x0, // b000 }; -/// SortMask tests on individual bits. -inline static bool includeDenseInput(unsigned mask) { - return mask & SortMask::kIncludeDenseInput; +inline static bool includesAny(SortMask mask1, SortMask mask2) { + return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2); } -inline static bool includeDenseOutput(unsigned mask) { - return mask & SortMask::kIncludeDenseOutput; +inline static bool includesDenseInput(SortMask mask) { + return includesAny(mask, SortMask::kIncludeDenseInput); } -inline static bool includeUndef(unsigned mask) { - return mask & SortMask::kIncludeUndef; +inline static bool includesDenseOutput(SortMask mask) { + return includesAny(mask, SortMask::kIncludeDenseOutput); +} + +inline static bool includesDense(SortMask mask) { + return includesAny(mask, SortMask::kIncludeDense); +} + +inline static bool includesUndef(SortMask mask) { + return includesAny(mask, SortMask::kIncludeUndef); } /// A helper class that visits an affine expression and tries to find an @@ -456,7 +463,7 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op, /// along fixed levels. Even for dense storage formats, however, the natural /// coordinate order yields innermost unit-stride access with better spatial /// locality. -static bool computeIterationGraph(CodegenEnv &env, unsigned mask, +static bool computeIterationGraph(CodegenEnv &env, SortMask mask, OpOperand *skip = nullptr) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. @@ -471,18 +478,17 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask, const auto enc = getSparseTensorEncoding(t.get().getType()); assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n); - bool isDenseInput = !enc && env.op().isDpsInput(&t); - bool isDenseOutput = !enc && !isDenseInput; - // Skips dense inputs/outputs when not requested. - if ((isDenseInput && !includeDenseInput(mask)) || - (isDenseOutput && !includeDenseOutput(mask))) + const bool isDenseInput = !enc && env.op().isDpsInput(&t); + const bool isDenseOutput = !enc && !isDenseInput; + if ((isDenseInput && !includesDenseInput(mask)) || + (isDenseOutput && !includesDenseOutput(mask))) continue; // Push unrelated loops into sparse iteration space, so these // will be skipped more often. // TODO: Do we really need this? - if (includeUndef(mask)) { + if (includesUndef(mask)) { unsigned tensor = t.getOperandNumber(); for (unsigned i = 0; i < n; i++) { if (isCompressedDLT(env.dlt(tensor, i)) || @@ -540,7 +546,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask, // E.g, for [dense, dense] -> (d0 + d1, d2 + d3). // It is totally fine to have loop sequence d0->d2->d1->d3 instead of // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. - if (!(mask & SortMask::kIncludeDense)) + if (!includesDense(mask)) tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta); // (d0 + d1) < (d2 + d3), or |