Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms')
 mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h                  |  6
 mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp               |  2
 mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp       |  2
 mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp    | 10
 mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp     |  3
 mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp |  2
 mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp            | 11
 7 files changed, 16 insertions(+), 20 deletions(-)
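
All seven files apply the same mechanical rename of the sparse-tensor level-type accessors: SparseTensorEncodingAttr::getDimLevelType() becomes getLvlTypes() (all level types at once), and the Merger lookups getDimLevelType(t, i) and getDimLevelType(b) become getLvlType(t, i) and getLvlType(b). A minimal before/after sketch of a typical call site (illustrative only, not part of the patch; enc is assumed to be a SparseTensorEncodingAttr):

  // Collect the per-level types stored in a sparse tensor encoding.
  SmallVector<DimLevelType> lvlTypes;
  // Before this patch:
  //   for (DimLevelType dlt : enc.getDimLevelType())
  //     lvlTypes.push_back(dlt);
  // After this patch (same semantics, pluralized accessor name):
  for (DimLevelType dlt : enc.getLvlTypes())
    lvlTypes.push_back(dlt);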
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index d9ef20220cae..3186889b7729 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -79,11 +79,9 @@ public:
const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
DimLevelType dlt(TensorId t, LoopId i) const {
- return latticeMerger.getDimLevelType(t, i);
- }
- DimLevelType dlt(TensorLoopId b) const {
- return latticeMerger.getDimLevelType(b);
+ return latticeMerger.getLvlType(t, i);
}
+ DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
 
//
// LoopEmitter delegates.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index d61e54505678..a50e337def72 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -288,7 +288,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
- for (auto lvlTp : enc.getDimLevelType())
+ for (auto lvlTp : enc.getLvlTypes())
lvlTypes[tid].push_back(lvlTp);
} else {
lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e729f725689d..0005c4c6a969 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1159,7 +1159,7 @@ public:
// TODO: We should check these in ExtractSliceOp::verify.
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
- assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
+ assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 906f700cfc47..4636615ed24b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -205,7 +205,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
SparseTensorType stt) {
SmallVector<Value> lvlTypes;
lvlTypes.reserve(stt.getLvlRank());
- for (const auto dlt : stt.getEncoding().getDimLevelType())
+ for (const auto dlt : stt.getEncoding().getLvlTypes())
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
return allocaBuffer(builder, loc, lvlTypes);
}
@@ -565,7 +565,7 @@ static void genSparseCOOIterationLoop(
rewriter.setInsertionPointToStart(after);
 
const bool hasDenseDim =
- llvm::any_of(stt.getEncoding().getDimLevelType(), isDenseDLT);
+ llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT);
if (hasDenseDim) {
Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
Value isZero = genIsNonzero(rewriter, loc, elemV);
@@ -880,11 +880,11 @@ public:
break;
case SparseToSparseConversionStrategy::kDirect:
useDirectConversion = true;
- assert(canUseDirectConversion(dstEnc.getDimLevelType()) &&
+ assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
"Unsupported target for direct sparse-to-sparse conversion");
break;
case SparseToSparseConversionStrategy::kAuto:
- useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType());
+ useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
break;
}
if (useDirectConversion) {
@@ -896,7 +896,7 @@ public:
// method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
const auto mixedEnc = SparseTensorEncodingAttr::get(
- op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(),
+ op->getContext(), dstEnc.getLvlTypes(), dstEnc.getDimOrdering(),
dstEnc.getHigherOrdering(), srcEnc.getPosWidth(),
srcEnc.getCrdWidth());
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a16ab660e931..6ee1c1b3dc49 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -44,8 +44,7 @@ static bool isZeroValue(Value val) {
// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
auto enc = getSparseTensorEncoding(op->get().getType());
- return enc &&
- llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed);
+ return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
}
 
// Helper method to find zero/uninitialized allocation.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index f45e3253adb0..a47d26e1b959 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -134,7 +134,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
if (!(callback(fidx, kind, dim, dlt))) \
return;
 
- const auto lvlTypes = enc.getDimLevelType();
+ const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
const Level cooStart = getCOOStart(enc);
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 681ba21dd4a3..9c2465d25737 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -232,7 +232,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId idx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
- if (!isUndefDLT(merger.getDimLevelType(tid, idx)))
+ if (!isUndefDLT(merger.getLvlType(tid, idx)))
return false; // used more than once
 
if (setLvlFormat)
@@ -243,7 +243,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
case AffineExprKind::Mul:
case AffineExprKind::Constant: {
if (!isDenseDLT(dlt) && setLvlFormat) {
- assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx)));
+ assert(isUndefDLT(merger.getLvlType(tid, filterLdx)));
// Use a filter loop for sparse affine expression.
merger.setLevelAndType(tid, filterLdx, lvl, dlt);
++filterLdx;
@@ -287,7 +287,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
- if (!isUndefDLT(merger.getDimLevelType(tensor, ldx)))
+ if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
return false; // used more than once, e.g., A[i][i]
 
// TODO: Generalizes the following two cases. A[i] (with trivial index
@@ -624,8 +624,7 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!ta.isa<AffineDimExpr>() &&
- !isDenseDLT(enc.getDimLevelType()[lvl]));
+ assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlTypes()[lvl]));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
// Now that the ordering of affine expression is captured by filter
// loop idx, we only need to ensure the affine ordering against filter
@@ -1922,7 +1921,7 @@ private:
//
auto srcTp = getRankedTensorType(tval);
auto dstEnc = SparseTensorEncodingAttr::get(
- getContext(), srcEnc.getDimLevelType(),
+ getContext(), srcEnc.getLvlTypes(),
permute(env, env.op().getMatchingIndexingMap(t)), // new order
srcEnc.getHigherOrdering(), srcEnc.getPosWidth(),
srcEnc.getCrdWidth());