author     wren romano <2998727+wrengr@users.noreply.github.com>  2023-05-17 13:09:53 -0700
committer  wren romano <2998727+wrengr@users.noreply.github.com>  2023-05-17 14:24:09 -0700
commit     a0615d020a02e252196383439e2c8143c6525e05 (patch)
tree       aa308ef0e4c62d7dba3450f0eb4f8f1dffc0f57c /mlir/lib
parent     4dc205f016e3dd2eb1182886a77676f24e39e329 (diff)
download   llvm-a0615d020a02e252196383439e2c8143c6525e05.tar.gz
[mlir][sparse] Renaming the STEA field `dimLevelType` to `lvlTypes`
This commit is part of the migration towards the new STEA syntax/design. In particular, it includes the following changes:

* Renaming compiler-internal functions/methods:
  * `SparseTensorEncodingAttr::{getDimLevelType => getLvlTypes}`
  * `Merger::{getDimLevelType => getLvlType}` (for consistency)
  * `sparse_tensor::{getDimLevelType => buildLevelType}` (to help reduce confusion with the actual getter methods)
* Renaming the external facets to match:
  * the STEA parser and printer
  * the C and Python bindings
  * PyTACO

However, the actual renaming of `DimLevelType` itself (along with all the "dlt" names) will be handled in a separate commit.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150330
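For illustration only (not part of this patch), here is a minimal C++ sketch of the renamed C-API facet. The entry points and parameter order are taken from the CAPI hunk below; `buildCsrLikeEncoding` is a hypothetical helper, and the enum spellings are assumed from `mlir-c/Dialect/SparseTensor.h`:

  // Hypothetical helper, not part of the patch: builds a 2-D encoding with a
  // dense and a compressed level, then reads the level types back.
  #include "mlir-c/Dialect/SparseTensor.h"

  static MlirAttribute buildCsrLikeEncoding(MlirContext ctx) {
    // Enum spellings assumed from the C header.
    const MlirSparseTensorDimLevelType lvlTypes[2] = {
        MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE,
        MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED};
    // The buffer parameter was `dimLevelTypes`; it now matches the renamed
    // `lvlTypes` field.
    MlirAttribute enc = mlirSparseTensorEncodingAttrGet(
        ctx, /*lvlRank=*/2, lvlTypes, /*dimOrdering=*/MlirAffineMap{nullptr},
        /*higherOrdering=*/MlirAffineMap{nullptr}, /*posWidth=*/0,
        /*crdWidth=*/0);
    // Formerly `mlirSparseTensorEncodingAttrGetDimLevelType`.
    for (intptr_t l = 0, e = mlirSparseTensorEncodingGetLvlRank(enc); l < e; ++l)
      (void)mlirSparseTensorEncodingAttrGetLvlType(enc, l);
    return enc;
  }

The same rename shows up in the Python bindings below, where the `dim_level_types` keyword argument and property become `lvl_types`.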
Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Bindings/Python/DialectSparseTensor.cpp                       | 12
-rw-r--r--  mlir/lib/CAPI/Dialect/SparseTensor.cpp                                 | 17
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp               | 48
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h                  |  6
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp               |  2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp       |  2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp    | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp     |  3
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp |  2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp            | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp                         |  6
11 files changed, 55 insertions, 64 deletions
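Before the per-file hunks, a minimal sketch (again not part of this patch) of how the renamed compiler-internal getters compose; `hasCompressedLevel` is a hypothetical helper used only to tie the names together:

  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

  // Hypothetical helper: returns true if any storage level of the encoding is
  // compressed. `getLvlTypes()` replaces `getDimLevelType()`, `getLvlType(l)`
  // indexes into it, and `getLvlRank()` is its length.
  static bool hasCompressedLevel(mlir::sparse_tensor::SparseTensorEncodingAttr enc) {
    using namespace mlir::sparse_tensor;
    for (Level l = 0, e = enc.getLvlRank(); l < e; ++l)
      if (isCompressedDLT(enc.getLvlType(l)))
        return true;
    return false;
  }

Callers that previously iterated `getDimLevelType()` directly, as several hunks below do via `llvm::all_of`/`llvm::any_of`, now iterate `getLvlTypes()` instead.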
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 0e07f256344f..0f0e676041b2 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -39,30 +39,28 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
mlirAttributeIsASparseTensorEncodingAttr)
.def_classmethod(
"get",
- [](py::object cls,
- std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
+ [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
std::optional<MlirAffineMap> dimOrdering,
std::optional<MlirAffineMap> higherOrdering, int posWidth,
int crdWidth, MlirContext context) {
return cls(mlirSparseTensorEncodingAttrGet(
- context, dimLevelTypes.size(), dimLevelTypes.data(),
+ context, lvlTypes.size(), lvlTypes.data(),
dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
higherOrdering ? *higherOrdering : MlirAffineMap{nullptr},
posWidth, crdWidth));
},
- py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
+ py::arg("cls"), py::arg("lvl_types"), py::arg("dim_ordering"),
py::arg("higher_ordering"), py::arg("pos_width"),
py::arg("crd_width"), py::arg("context") = py::none(),
"Gets a sparse_tensor.encoding from parameters.")
.def_property_readonly(
- "dim_level_types",
+ "lvl_types",
[](MlirAttribute self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
std::vector<MlirSparseTensorDimLevelType> ret;
ret.reserve(lvlRank);
for (int l = 0; l < lvlRank; ++l)
- ret.push_back(
- mlirSparseTensorEncodingAttrGetDimLevelType(self, l));
+ ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
return ret;
})
.def_property_readonly(
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 795ce51ff9f0..8569acf43613 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -47,16 +47,15 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
MlirAttribute mlirSparseTensorEncodingAttrGet(
MlirContext ctx, intptr_t lvlRank,
- MlirSparseTensorDimLevelType const *dimLevelTypes,
- MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth,
- int crdWidth) {
- SmallVector<DimLevelType> cppDimLevelTypes;
- cppDimLevelTypes.reserve(lvlRank);
+ MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimOrdering,
+ MlirAffineMap higherOrdering, int posWidth, int crdWidth) {
+ SmallVector<DimLevelType> cppLvlTypes;
+ cppLvlTypes.reserve(lvlRank);
for (intptr_t l = 0; l < lvlRank; ++l)
- cppDimLevelTypes.push_back(static_cast<DimLevelType>(dimLevelTypes[l]));
+ cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
return wrap(SparseTensorEncodingAttr::get(
- unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering),
- unwrap(higherOrdering), posWidth, crdWidth));
+ unwrap(ctx), cppLvlTypes, unwrap(dimOrdering), unwrap(higherOrdering),
+ posWidth, crdWidth));
}
MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
@@ -73,7 +72,7 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
}
MlirSparseTensorDimLevelType
-mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) {
+mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
return static_cast<MlirSparseTensorDimLevelType>(
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 2def7ccfba94..22d6304dcb41 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -130,23 +130,22 @@ Type SparseTensorEncodingAttr::getCrdType() const {
}
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
- return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(),
- AffineMap(), AffineMap(), getPosWidth(),
+ return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), AffineMap(),
+ AffineMap(), getPosWidth(),
getCrdWidth());
}
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
- return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(),
- getDimOrdering(), getHigherOrdering(), 0,
- 0);
+ return SparseTensorEncodingAttr::get(
+ getContext(), getLvlTypes(), getDimOrdering(), getHigherOrdering(), 0, 0);
}
bool SparseTensorEncodingAttr::isAllDense() const {
- return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT);
+ return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}
bool SparseTensorEncodingAttr::isAllOrdered() const {
- return !getImpl() || llvm::all_of(getDimLevelType(), isOrderedDLT);
+ return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
}
bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
@@ -155,14 +154,14 @@ bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
Level SparseTensorEncodingAttr::getLvlRank() const {
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
- return getDimLevelType().size();
+ return getLvlTypes().size();
}
DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
if (!getImpl())
return DimLevelType::Dense;
assert(l < getLvlRank() && "Level is out of bounds");
- return getDimLevelType()[l];
+ return getLvlTypes()[l];
}
std::optional<uint64_t>
@@ -243,9 +242,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
StringRef attrName;
// Exactly 6 keys.
- SmallVector<StringRef, 6> keys = {"dimLevelType", "dimOrdering",
- "higherOrdering", "posWidth",
- "crdWidth", "slice"};
+ SmallVector<StringRef, 6> keys = {"lvlTypes", "dimOrdering", "higherOrdering",
+ "posWidth", "crdWidth", "slice"};
while (succeeded(parser.parseOptionalKeyword(&attrName))) {
if (!llvm::is_contained(keys, attrName)) {
parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
@@ -258,7 +256,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
// cost of the `is_contained` check above. Should instead use some
// "find" function that returns the index into `keys` so that we can
// dispatch on that instead.
- if (attrName == "dimLevelType") {
+ if (attrName == "lvlTypes") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr));
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr);
@@ -336,8 +334,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
// Print the struct-like storage in dictionary fashion.
- printer << "<{ dimLevelType = [ ";
- llvm::interleaveComma(getDimLevelType(), printer, [&](DimLevelType dlt) {
+ printer << "<{ lvlTypes = [ ";
+ llvm::interleaveComma(getLvlTypes(), printer, [&](DimLevelType dlt) {
printer << "\"" << toMLIRString(dlt) << "\"";
});
printer << " ]";
@@ -366,7 +364,7 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
LogicalResult SparseTensorEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
+ ArrayRef<DimLevelType> lvlTypes, AffineMap dimOrdering,
AffineMap higherOrdering, unsigned posWidth, unsigned crdWidth,
ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
if (!acceptBitWidth(posWidth))
@@ -378,7 +376,7 @@ LogicalResult SparseTensorEncodingAttr::verify(
// the `getLvlRank` method is the length of the level-types array,
// since it must always be provided and have full rank; therefore we
// use that same source-of-truth here.
- const Level lvlRank = dimLevelType.size();
+ const Level lvlRank = lvlTypes.size();
if (lvlRank == 0)
return emitError() << "expected a non-empty array for level types";
if (dimOrdering) {
@@ -415,9 +413,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
function_ref<InFlightDiagnostic()> emitError) const {
// Check structural integrity. In particular, this ensures that the
// level-rank is coherent across all the fields.
- RETURN_FAILURE_IF_FAILED(verify(emitError, getDimLevelType(),
- getDimOrdering(), getHigherOrdering(),
- getPosWidth(), getCrdWidth(), getDimSlices()))
+ RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimOrdering(),
+ getHigherOrdering(), getPosWidth(),
+ getCrdWidth(), getDimSlices()))
// Check integrity with tensor type specifics. In particular, we
// need only check that the dimension-rank of the tensor agrees with
// the dimension-rank of the encoding.
@@ -496,14 +494,14 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
// An unordered and non-unique compressed level at beginning.
// If this is also the last level, then it is unique.
lvlTypes.push_back(
- *getDimLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+ *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
if (lvlRank > 1) {
// TODO: it is actually ordered at the level for ordered input.
// Followed by unordered non-unique n-2 singleton levels.
std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
- *getDimLevelType(LevelFormat::Singleton, ordered, false));
+ *buildLevelType(LevelFormat::Singleton, ordered, false));
// Ends by a unique singleton level unless the lvlRank is 1.
- lvlTypes.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true));
+ lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
}
// TODO: Maybe pick the bitwidth based on input/output tensors (probably the
@@ -580,8 +578,8 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<DimLevelType> dlts;
- for (auto dlt : enc.getDimLevelType())
- dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
+ for (auto dlt : enc.getLvlTypes())
+ dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));
return SparseTensorEncodingAttr::get(
enc.getContext(), dlts,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index d9ef20220cae..3186889b7729 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -79,11 +79,9 @@ public:
const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
DimLevelType dlt(TensorId t, LoopId i) const {
- return latticeMerger.getDimLevelType(t, i);
- }
- DimLevelType dlt(TensorLoopId b) const {
- return latticeMerger.getDimLevelType(b);
+ return latticeMerger.getLvlType(t, i);
}
+ DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
//
// LoopEmitter delegates.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index d61e54505678..a50e337def72 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -288,7 +288,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
- for (auto lvlTp : enc.getDimLevelType())
+ for (auto lvlTp : enc.getLvlTypes())
lvlTypes[tid].push_back(lvlTp);
} else {
lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e729f725689d..0005c4c6a969 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1159,7 +1159,7 @@ public:
// TODO: We should check these in ExtractSliceOp::verify.
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
- assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
+ assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 906f700cfc47..4636615ed24b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -205,7 +205,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
SparseTensorType stt) {
SmallVector<Value> lvlTypes;
lvlTypes.reserve(stt.getLvlRank());
- for (const auto dlt : stt.getEncoding().getDimLevelType())
+ for (const auto dlt : stt.getEncoding().getLvlTypes())
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
return allocaBuffer(builder, loc, lvlTypes);
}
@@ -565,7 +565,7 @@ static void genSparseCOOIterationLoop(
rewriter.setInsertionPointToStart(after);
const bool hasDenseDim =
- llvm::any_of(stt.getEncoding().getDimLevelType(), isDenseDLT);
+ llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT);
if (hasDenseDim) {
Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
Value isZero = genIsNonzero(rewriter, loc, elemV);
@@ -880,11 +880,11 @@ public:
break;
case SparseToSparseConversionStrategy::kDirect:
useDirectConversion = true;
- assert(canUseDirectConversion(dstEnc.getDimLevelType()) &&
+ assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
"Unsupported target for direct sparse-to-sparse conversion");
break;
case SparseToSparseConversionStrategy::kAuto:
- useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType());
+ useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
break;
}
if (useDirectConversion) {
@@ -896,7 +896,7 @@ public:
// method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
const auto mixedEnc = SparseTensorEncodingAttr::get(
- op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(),
+ op->getContext(), dstEnc.getLvlTypes(), dstEnc.getDimOrdering(),
dstEnc.getHigherOrdering(), srcEnc.getPosWidth(),
srcEnc.getCrdWidth());
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a16ab660e931..6ee1c1b3dc49 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -44,8 +44,7 @@ static bool isZeroValue(Value val) {
// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
auto enc = getSparseTensorEncoding(op->get().getType());
- return enc &&
- llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed);
+ return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
}
// Helper method to find zero/uninitialized allocation.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index f45e3253adb0..a47d26e1b959 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -134,7 +134,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
if (!(callback(fidx, kind, dim, dlt))) \
return;
- const auto lvlTypes = enc.getDimLevelType();
+ const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
const Level cooStart = getCOOStart(enc);
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 681ba21dd4a3..9c2465d25737 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -232,7 +232,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId idx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
- if (!isUndefDLT(merger.getDimLevelType(tid, idx)))
+ if (!isUndefDLT(merger.getLvlType(tid, idx)))
return false; // used more than once
if (setLvlFormat)
@@ -243,7 +243,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
case AffineExprKind::Mul:
case AffineExprKind::Constant: {
if (!isDenseDLT(dlt) && setLvlFormat) {
- assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx)));
+ assert(isUndefDLT(merger.getLvlType(tid, filterLdx)));
// Use a filter loop for sparse affine expression.
merger.setLevelAndType(tid, filterLdx, lvl, dlt);
++filterLdx;
@@ -287,7 +287,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
- if (!isUndefDLT(merger.getDimLevelType(tensor, ldx)))
+ if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
return false; // used more than once, e.g., A[i][i]
// TODO: Generalizes the following two cases. A[i] (with trivial index
@@ -624,8 +624,7 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!ta.isa<AffineDimExpr>() &&
- !isDenseDLT(enc.getDimLevelType()[lvl]));
+ assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlTypes()[lvl]));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
// Now that the ordering of affine expression is captured by filter
// loop idx, we only need to ensure the affine ordering against filter
@@ -1922,7 +1921,7 @@ private:
//
auto srcTp = getRankedTensorType(tval);
auto dstEnc = SparseTensorEncodingAttr::get(
- getContext(), srcEnc.getDimLevelType(),
+ getContext(), srcEnc.getLvlTypes(),
permute(env, env.op().getMatchingIndexingMap(t)), // new order
srcEnc.getHigherOrdering(), srcEnc.getPosWidth(),
srcEnc.getCrdWidth());
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index ae31af0cc572..c546a7f5e1c5 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -405,7 +405,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Starts resetting from a dense level, so that the first bit (if kept)
// is not undefined level-type.
for (unsigned b = 0; b < be; b++) {
- if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) {
+ if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
offset = be - b - 1; // relative to the end
break;
}
@@ -417,7 +417,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
b = b == 0 ? be - 1 : b - 1, i++) {
// Slice on dense level has `locate` property as well, and can be optimized.
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
- const auto dlt = getDimLevelType(b);
+ const auto dlt = getLvlType(b);
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && !isCompressedWithHiDLT(dlt)) {
if (reset)
simple.reset(b);
@@ -584,7 +584,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
- const auto dlt = getDimLevelType(b);
+ const auto dlt = getLvlType(b);
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt))
return true;
}