diff options
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor')
9 files changed, 42 insertions, 48 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 2def7ccfba94..22d6304dcb41 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -130,23 +130,22 @@ Type SparseTensorEncodingAttr::getCrdType() const { } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const { - return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(), - AffineMap(), AffineMap(), getPosWidth(), + return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), AffineMap(), + AffineMap(), getPosWidth(), getCrdWidth()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { - return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(), - getDimOrdering(), getHigherOrdering(), 0, - 0); + return SparseTensorEncodingAttr::get( + getContext(), getLvlTypes(), getDimOrdering(), getHigherOrdering(), 0, 0); } bool SparseTensorEncodingAttr::isAllDense() const { - return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT); + return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT); } bool SparseTensorEncodingAttr::isAllOrdered() const { - return !getImpl() || llvm::all_of(getDimLevelType(), isOrderedDLT); + return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT); } bool SparseTensorEncodingAttr::hasIdDimOrdering() const { @@ -155,14 +154,14 @@ bool SparseTensorEncodingAttr::hasIdDimOrdering() const { Level SparseTensorEncodingAttr::getLvlRank() const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); - return getDimLevelType().size(); + return getLvlTypes().size(); } DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const { if (!getImpl()) return DimLevelType::Dense; assert(l < getLvlRank() && "Level is out of bounds"); - return getDimLevelType()[l]; + return getLvlTypes()[l]; } std::optional<uint64_t> @@ -243,9 +242,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { StringRef attrName; // Exactly 6 keys. - SmallVector<StringRef, 6> keys = {"dimLevelType", "dimOrdering", - "higherOrdering", "posWidth", - "crdWidth", "slice"}; + SmallVector<StringRef, 6> keys = {"lvlTypes", "dimOrdering", "higherOrdering", + "posWidth", "crdWidth", "slice"}; while (succeeded(parser.parseOptionalKeyword(&attrName))) { if (!llvm::is_contained(keys, attrName)) { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; @@ -258,7 +256,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { // cost of the `is_contained` check above. Should instead use some // "find" function that returns the index into `keys` so that we can // dispatch on that instead. - if (attrName == "dimLevelType") { + if (attrName == "lvlTypes") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)); auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr); @@ -336,8 +334,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { // Print the struct-like storage in dictionary fashion. - printer << "<{ dimLevelType = [ "; - llvm::interleaveComma(getDimLevelType(), printer, [&](DimLevelType dlt) { + printer << "<{ lvlTypes = [ "; + llvm::interleaveComma(getLvlTypes(), printer, [&](DimLevelType dlt) { printer << "\"" << toMLIRString(dlt) << "\""; }); printer << " ]"; @@ -366,7 +364,7 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { LogicalResult SparseTensorEncodingAttr::verify( function_ref<InFlightDiagnostic()> emitError, - ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, + ArrayRef<DimLevelType> lvlTypes, AffineMap dimOrdering, AffineMap higherOrdering, unsigned posWidth, unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) { if (!acceptBitWidth(posWidth)) @@ -378,7 +376,7 @@ LogicalResult SparseTensorEncodingAttr::verify( // the `getLvlRank` method is the length of the level-types array, // since it must always be provided and have full rank; therefore we // use that same source-of-truth here. - const Level lvlRank = dimLevelType.size(); + const Level lvlRank = lvlTypes.size(); if (lvlRank == 0) return emitError() << "expected a non-empty array for level types"; if (dimOrdering) { @@ -415,9 +413,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding( function_ref<InFlightDiagnostic()> emitError) const { // Check structural integrity. In particular, this ensures that the // level-rank is coherent across all the fields. - RETURN_FAILURE_IF_FAILED(verify(emitError, getDimLevelType(), - getDimOrdering(), getHigherOrdering(), - getPosWidth(), getCrdWidth(), getDimSlices())) + RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimOrdering(), + getHigherOrdering(), getPosWidth(), + getCrdWidth(), getDimSlices())) // Check integrity with tensor type specifics. In particular, we // need only check that the dimension-rank of the tensor agrees with // the dimension-rank of the encoding. @@ -496,14 +494,14 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt, // An unordered and non-unique compressed level at beginning. // If this is also the last level, then it is unique. lvlTypes.push_back( - *getDimLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); + *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); if (lvlRank > 1) { // TODO: it is actually ordered at the level for ordered input. // Followed by unordered non-unique n-2 singleton levels. std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, - *getDimLevelType(LevelFormat::Singleton, ordered, false)); + *buildLevelType(LevelFormat::Singleton, ordered, false)); // Ends by a unique singleton level unless the lvlRank is 1. - lvlTypes.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true)); + lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); } // TODO: Maybe pick the bitwidth based on input/output tensors (probably the @@ -580,8 +578,8 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) { static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { SmallVector<DimLevelType> dlts; - for (auto dlt : enc.getDimLevelType()) - dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true)); + for (auto dlt : enc.getLvlTypes()) + dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true)); return SparseTensorEncodingAttr::get( enc.getContext(), dlts, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h index d9ef20220cae..3186889b7729 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -79,11 +79,9 @@ public: const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); } ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); } DimLevelType dlt(TensorId t, LoopId i) const { - return latticeMerger.getDimLevelType(t, i); - } - DimLevelType dlt(TensorLoopId b) const { - return latticeMerger.getDimLevelType(b); + return latticeMerger.getLvlType(t, i); } + DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getLvlType(b); } // // LoopEmitter delegates. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index d61e54505678..a50e337def72 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -288,7 +288,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) { const auto enc = stt.getEncoding(); isSparseSlices[tid] = enc.isSlice(); - for (auto lvlTp : enc.getDimLevelType()) + for (auto lvlTp : enc.getLvlTypes()) lvlTypes[tid].push_back(lvlTp); } else { lvlTypes[tid].assign(lvlRank, DimLevelType::Dense); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index e729f725689d..0005c4c6a969 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1159,7 +1159,7 @@ public: // TODO: We should check these in ExtractSliceOp::verify. if (!srcEnc || !dstEnc || !dstEnc.isSlice()) return failure(); - assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType()); + assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes()); assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering()); assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering()); assert(srcEnc.getPosWidth() == dstEnc.getPosWidth()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 906f700cfc47..4636615ed24b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -205,7 +205,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, SparseTensorType stt) { SmallVector<Value> lvlTypes; lvlTypes.reserve(stt.getLvlRank()); - for (const auto dlt : stt.getEncoding().getDimLevelType()) + for (const auto dlt : stt.getEncoding().getLvlTypes()) lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); return allocaBuffer(builder, loc, lvlTypes); } @@ -565,7 +565,7 @@ static void genSparseCOOIterationLoop( rewriter.setInsertionPointToStart(after); const bool hasDenseDim = - llvm::any_of(stt.getEncoding().getDimLevelType(), isDenseDLT); + llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT); if (hasDenseDim) { Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr); Value isZero = genIsNonzero(rewriter, loc, elemV); @@ -880,11 +880,11 @@ public: break; case SparseToSparseConversionStrategy::kDirect: useDirectConversion = true; - assert(canUseDirectConversion(dstEnc.getDimLevelType()) && + assert(canUseDirectConversion(dstEnc.getLvlTypes()) && "Unsupported target for direct sparse-to-sparse conversion"); break; case SparseToSparseConversionStrategy::kAuto: - useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType()); + useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes()); break; } if (useDirectConversion) { @@ -896,7 +896,7 @@ public: // method calls can share most parameters, while still providing // the correct sparsity information to either of them. const auto mixedEnc = SparseTensorEncodingAttr::get( - op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(), + op->getContext(), dstEnc.getLvlTypes(), dstEnc.getDimOrdering(), dstEnc.getHigherOrdering(), srcEnc.getPosWidth(), srcEnc.getCrdWidth()); // TODO: This is the only place where `kToCOO` (or `kToIterator`) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index a16ab660e931..6ee1c1b3dc49 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -44,8 +44,7 @@ static bool isZeroValue(Value val) { // Helper to detect a sparse tensor type operand. static bool isSparseTensor(OpOperand *op) { auto enc = getSparseTensorEncoding(op->get().getType()); - return enc && - llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed); + return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed); } // Helper method to find zero/uninitialized allocation. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp index f45e3253adb0..a47d26e1b959 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -134,7 +134,7 @@ void sparse_tensor::foreachFieldInSparseTensor( if (!(callback(fidx, kind, dim, dlt))) \ return; - const auto lvlTypes = enc.getDimLevelType(); + const auto lvlTypes = enc.getLvlTypes(); const Level lvlRank = enc.getLvlRank(); const Level cooStart = getCOOStart(enc); const Level end = cooStart == lvlRank ? cooStart : cooStart + 1; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 681ba21dd4a3..9c2465d25737 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -232,7 +232,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, switch (a.getKind()) { case AffineExprKind::DimId: { const LoopId idx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition()); - if (!isUndefDLT(merger.getDimLevelType(tid, idx))) + if (!isUndefDLT(merger.getLvlType(tid, idx))) return false; // used more than once if (setLvlFormat) @@ -243,7 +243,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, case AffineExprKind::Mul: case AffineExprKind::Constant: { if (!isDenseDLT(dlt) && setLvlFormat) { - assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx))); + assert(isUndefDLT(merger.getLvlType(tid, filterLdx))); // Use a filter loop for sparse affine expression. merger.setLevelAndType(tid, filterLdx, lvl, dlt); ++filterLdx; @@ -287,7 +287,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, switch (a.getKind()) { case AffineExprKind::DimId: { const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition()); - if (!isUndefDLT(merger.getDimLevelType(tensor, ldx))) + if (!isUndefDLT(merger.getLvlType(tensor, ldx))) return false; // used more than once, e.g., A[i][i] // TODO: Generalizes the following two cases. A[i] (with trivial index @@ -624,8 +624,7 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t, // Filter loops should be constructed after all the dependent loops, // i.e., d0 + d1 < filter_loop(d0 + d1) if (tldx && env.merger().isFilterLoop(*tldx)) { - assert(!ta.isa<AffineDimExpr>() && - !isDenseDLT(enc.getDimLevelType()[lvl])); + assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlTypes()[lvl])); addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx); // Now that the ordering of affine expression is captured by filter // loop idx, we only need to ensure the affine ordering against filter @@ -1922,7 +1921,7 @@ private: // auto srcTp = getRankedTensorType(tval); auto dstEnc = SparseTensorEncodingAttr::get( - getContext(), srcEnc.getDimLevelType(), + getContext(), srcEnc.getLvlTypes(), permute(env, env.op().getMatchingIndexingMap(t)), // new order srcEnc.getHigherOrdering(), srcEnc.getPosWidth(), srcEnc.getCrdWidth()); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index ae31af0cc572..c546a7f5e1c5 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -405,7 +405,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { // Starts resetting from a dense level, so that the first bit (if kept) // is not undefined level-type. for (unsigned b = 0; b < be; b++) { - if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) { + if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) { offset = be - b - 1; // relative to the end break; } @@ -417,7 +417,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { b = b == 0 ? be - 1 : b - 1, i++) { // Slice on dense level has `locate` property as well, and can be optimized. if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { - const auto dlt = getDimLevelType(b); + const auto dlt = getLvlType(b); if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && !isCompressedWithHiDLT(dlt)) { if (reset) simple.reset(b); @@ -584,7 +584,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const { bool Merger::hasAnySparse(const BitVector &bits) const { for (TensorLoopId b : bits.set_bits()) { - const auto dlt = getDimLevelType(b); + const auto dlt = getLvlType(b); if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt)) return true; } |