author    Peiming Liu <peiming@google.com>    2023-05-12 20:33:49 +0000
committer Peiming Liu <peiming@google.com>    2023-05-16 17:22:44 +0000
commit    ad469385ab37e7bd883f5e4b75459e480c0d4416 (patch)
tree      edad6d4d2e86085edf09d2d86d6995ebfce55fd9 /mlir
parent    689de4c6759fa810d827aee06a0ab060b01172ce (diff)
[mlir][sparse] Add a helper class to help lower operations with/without function calls
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D150477
Diffstat (limited to 'mlir')
-rw-r--r-- mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h          |  66
-rw-r--r-- mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 258
2 files changed, 181 insertions(+), 143 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 9e762892e864..caa549e45e41 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -72,6 +72,72 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
// Misc code generators and utilities.
//===----------------------------------------------------------------------===//
+/// A helper class to simplify lowering operations with/without function calls.
+template <class SubClass>
+class FuncCallOrInlineGenerator {
+public:
+ FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall)
+ : retTypes(retTypes), params(params), genCall(genCall) {}
+
+ // The main API invoked by clients; it abstracts away the details of
+ // creating function calls.
+ SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) {
+ if (!genCall)
+ return genImplementation(retTypes, params, builder, loc);
+
+ // Looks up the function.
+ std::string funcName = getMangledFuncName();
+ ModuleOp module = getParentOpOf<ModuleOp>(builder);
+ MLIRContext *context = module.getContext();
+ auto result = SymbolRefAttr::get(context, funcName);
+ auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+ if (!func) {
+ // Create the function if it does not already exist.
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder));
+ func = builder.create<func::FuncOp>(
+ loc, funcName,
+ FunctionType::get(context, params.getTypes(), retTypes));
+ func.setPrivate();
+ // Set the insertion point to the body of the function.
+ Block *entryBB = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBB);
+ ValueRange args = entryBB->getArguments();
+ // Delegate to the user to generate the actual implementation.
+ SmallVector<Value> result =
+ genImplementation(retTypes, args, builder, loc);
+ builder.create<func::ReturnOp>(loc, result);
+ }
+ // Returns the CallOp results.
+ func::CallOp call = builder.create<func::CallOp>(loc, func, params);
+ return call.getResults();
+ }
+
+private:
+ template <class OpTp>
+ OpTp getParentOpOf(OpBuilder &builder) {
+ return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>();
+ }
+
+ // CRTP: get the mangled function name (only called when genCall=true).
+ std::string getMangledFuncName() {
+ return static_cast<SubClass *>(this)->getMangledFuncName();
+ }
+
+ // CRTP: Client implementation.
+ SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params,
+ OpBuilder &builder, Location loc) {
+ return static_cast<SubClass *>(this)->genImplementation(retTypes, params,
+ builder, loc);
+ }
+
+private:
+ TypeRange retTypes; // The types of all returned results
+ ValueRange params; // The values of all input parameters
+ bool genCall; // Whether the implementation should be wrapped in a function call
+};
+
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
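For illustration, here is a minimal sketch of how a client of the new CRTP helper might look. The generator name, the computation in its body, and the surrounding includes are hypothetical, not part of this patch:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical client: adds 1 to a single index-typed parameter.
class AddOneGenerator : public FuncCallOrInlineGenerator<AddOneGenerator> {
public:
  AddOneGenerator(TypeRange retTypes, ValueRange params, bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall) {}

  // CRTP hook: a name unique to this computation, so repeated lowerings
  // reuse one outlined function instead of duplicating the body.
  std::string getMangledFuncName() { return "_add_one_index"; }

  // CRTP hook: emits the actual body. `args` is either the original SSA
  // values (inline mode) or the entry-block arguments of the outlined
  // function (call mode); the implementation cannot tell the difference.
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
    Value sum = builder.create<arith::AddIOp>(loc, args[0], one);
    return {sum};
  }
};

// At the call site, the same one-liner either inlines the body or emits a
// private func.func plus a func.call, depending on the flag:
//   SmallVector<Value> results =
//       AddOneGenerator(retTypes, params, /*genCall=*/true)
//           .genCallOrInline(builder, loc);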
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 88f79bf3e8d4..e729f725689d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -42,8 +42,6 @@ namespace {
using FuncGeneratorType =
function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
-static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
-
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
@@ -396,134 +394,102 @@ static Value genCompressed(OpBuilder &builder, Location loc,
return ifOp2.getResult(o);
}
-/// Generates code along an insertion path without the need for a "cursor".
-/// This current insertion strategy comes at the expense of some testing
-/// overhead for each insertion. The strategy will be optimized later for
-/// common insertion patterns. The current insertion strategy also assumes
-/// insertions occur in "a reasonable order" that enables building the
-/// storage scheme in an appending/inserting kind of fashion (i.e. no
-/// in-between insertions that need data movement). The implementation
-/// relies on CSE/DCE to clean up all bookkeeping that is not needed.
-///
-/// TODO: better unord/not-unique; also generalize, optimize, specialize!
-///
-static void genInsertBody(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, RankedTensorType rtp) {
- const OpBuilder::InsertionGuard insertionGuard(builder);
- Block *const entryBlock = func.addEntryBlock();
- builder.setInsertionPointToStart(entryBlock);
- const ValueRange args = entryBlock->getArguments();
- const Location loc = func.getLoc();
- const SparseTensorType stt(rtp);
- const Level lvlRank = stt.getLvlRank();
-
- // Extract fields and coordinates from args.
- SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
- MutSparseTensorDescriptor desc(rtp, fields);
- const SmallVector<Value> coords =
- llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
- Value value = args.back();
- Value parentPos = constantZero(builder, loc, builder.getIndexType());
- // Generate code for every level.
- for (Level l = 0; l < lvlRank; l++) {
- const auto dlt = stt.getLvlType(l);
- if (isCompressedDLT(dlt)) {
- // Create:
- // if (!present) {
- // coordinates[l].push_back(coords[l])
- // <update positions and prepare level l + 1>
- // }
- // positions[l] = coordinates.size() - 1
- // <insert @ positions[l] at next level l + 1>
- parentPos =
- genCompressed(builder, loc, desc, coords, value, parentPos, l);
- } else if (isSingletonDLT(dlt)) {
- // Create:
- // coordinates[l].push_back(coords[l])
- // positions[l] = positions[l-1]
- // <insert @ positions[l] at next level l + 1>
- createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
- coords[l]);
- } else {
- assert(isDenseDLT(dlt));
- // Construct the new position as:
- // positions[l] = size * positions[l-1] + coords[l]
- // <insert @ positions[l] at next level l + 1>
- Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
- Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
- parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
+/// Helper class for lowering the sparse_tensor.insert operation.
+class SparseInsertGenerator
+ : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
+public:
+ SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
+ bool genCall)
+ : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {}
+
+ /// Generates code along an insertion path without the need for a "cursor".
+ /// This current insertion strategy comes at the expense of some testing
+ /// overhead for each insertion. The strategy will be optimized later for
+ /// common insertion patterns. The current insertion strategy also assumes
+ /// insertions occur in "a reasonable order" that enables building the
+ /// storage scheme in an appending/inserting kind of fashion (i.e. no
+ /// in-between insertions that need data movement). The implementation
+ /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
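+ /// For example, for a CSR matrix this amounts to inserting the nonzeros in
+ /// lexicographic coordinate order, so that every insertion merely appends
+ /// to the position, coordinate, and value arrays.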
+ ///
+ /// TODO: better unord/not-unique; also generalize, optimize, specialize!
+ SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
+ OpBuilder &builder, Location loc) {
+ const SparseTensorType stt(rtp.cast<RankedTensorType>());
+ const Level lvlRank = stt.getLvlRank();
+ // Extract fields and coordinates from args.
+ SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
+ MutSparseTensorDescriptor desc(stt, fields);
+ const SmallVector<Value> coords =
+ llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
+ Value value = args.back();
+ Value parentPos = constantZero(builder, loc, builder.getIndexType());
+ // Generate code for every level.
+ for (Level l = 0; l < lvlRank; l++) {
+ const auto dlt = stt.getLvlType(l);
+ if (isCompressedDLT(dlt)) {
+ // Create:
+ // if (!present) {
+ // coordinates[l].push_back(coords[l])
+ // <update positions and prepare level l + 1>
+ // }
+ // positions[l] = coordinates.size() - 1
+ // <insert @ positions[l] at next level l + 1>
+ parentPos =
+ genCompressed(builder, loc, desc, coords, value, parentPos, l);
+ } else if (isSingletonDLT(dlt)) {
+ // Create:
+ // coordinates[l].push_back(coords[l])
+ // positions[l] = positions[l-1]
+ // <insert @ positions[l] at next level l + 1>
+ createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
+ coords[l]);
+ } else {
+ assert(isDenseDLT(dlt));
+ // Construct the new position as:
+ // positions[l] = size * positions[l-1] + coords[l]
+ // <insert @ positions[l] at next level l + 1>
+ Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+ Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
+ parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
+ }
}
+ // Reached the actual value append/insert.
+ if (!stt.isDenseLvl(lvlRank - 1))
+ createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
+ std::nullopt, value);
+ else
+ genStore(builder, loc, value, desc.getValMemRef(), parentPos);
+ return fields;
}
- // Reached the actual value append/insert.
- if (!stt.isDenseLvl(lvlRank - 1))
- createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
- std::nullopt, value);
- else
- genStore(builder, loc, value, desc.getValMemRef(), parentPos);
- builder.create<func::ReturnOp>(loc, fields);
-}
-/// Generates a call to a function to perform an insertion operation. If the
-/// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder,
- MutSparseTensorDescriptor desc,
- SmallVectorImpl<Value> &lcvs, Value value,
- func::FuncOp insertPoint,
- StringRef namePrefix,
- FuncGeneratorType createFunc) {
- // The mangled name of the function has this format:
- // <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
- const SparseTensorType stt(desc.getRankedTensorType());
- SmallString<32> nameBuffer;
- llvm::raw_svector_ostream nameOstream(nameBuffer);
- nameOstream << namePrefix;
- const Level lvlRank = stt.getLvlRank();
- assert(lcvs.size() == static_cast<size_t>(lvlRank));
- for (Level l = 0; l < lvlRank; l++)
- nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
- // Static dim sizes are used in the generated code while dynamic sizes are
- // loaded from the dimSizes buffer. This is the reason for adding the shape
- // to the function name.
- for (const auto sh : stt.getDimShape())
- nameOstream << sh << "_";
- // Permutation information is also used in generating insertion.
- if (!stt.isIdentity())
- nameOstream << stt.getDimToLvlMap() << "_";
- nameOstream << stt.getElementType() << "_";
- nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
-
- // Look up the function.
- ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
- MLIRContext *context = module.getContext();
- auto result = SymbolRefAttr::get(context, nameOstream.str());
- auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
-
- // Construct operands: fields, coords, and value.
- SmallVector<Value> operands = llvm::to_vector(desc.getFields());
- operands.append(lcvs);
- operands.push_back(value);
- Location loc = insertPoint.getLoc();
-
- if (!func) {
- // Create the function.
- OpBuilder::InsertionGuard insertionGuard(builder);
- builder.setInsertionPoint(insertPoint);
-
- func = builder.create<func::FuncOp>(
- loc, nameOstream.str(),
- FunctionType::get(context, ValueRange(operands).getTypes(),
- ValueRange(desc.getFields()).getTypes()));
- func.setPrivate();
- createFunc(builder, module, func, stt);
+ std::string getMangledFuncName() {
+ // The mangled name of the function has this format:
+ // <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
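+ // For example (illustrative): a static 100x100 CSR matrix of f64 with the
+ // default (native) coordinate/position widths would mangle to
+ // "_insert_dense_compressed_100_100_f64_0_0".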
+ constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+ const SparseTensorType stt(rtp.cast<RankedTensorType>());
+
+ SmallString<32> nameBuffer;
+ llvm::raw_svector_ostream nameOstream(nameBuffer);
+ nameOstream << kInsertFuncNamePrefix;
+ const Level lvlRank = stt.getLvlRank();
+ for (Level l = 0; l < lvlRank; l++)
+ nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
+ // Static dim sizes are used in the generated code while dynamic sizes are
+ // loaded from the dimSizes buffer. This is the reason for adding the shape
+ // to the function name.
+ for (const auto sh : stt.getDimShape())
+ nameOstream << sh << "_";
+ // Permutation information is also used in generating insertion.
+ if (!stt.isIdentity())
+ nameOstream << stt.getDimToLvlMap() << "_";
+ nameOstream << stt.getElementType() << "_";
+ nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
+ return nameOstream.str().str();
}
- // Generate a call to perform the insertion and update `fields` with values
- // returned from the call.
- func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
- for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
- desc.getFields()[i] = call.getResult(i);
- }
-}
+private:
+ TensorType rtp;
+};
/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
@@ -936,8 +902,7 @@ public:
Value count = adaptor.getCount();
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();
- // Prepare level-coords.
- SmallVector<Value> lcvs(adaptor.getLvlCoords());
+
// If the innermost level is ordered, we need to sort the coordinates
// in the "added" array prior to applying the compression.
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
@@ -960,16 +925,22 @@ public:
// }
scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
Value i = loop.getInductionVar();
+
Value crd = genLoad(rewriter, loc, added, i);
Value value = genLoad(rewriter, loc, values, crd);
- lcvs.push_back(crd);
- // TODO: faster for subsequent insertions?
- auto insertPoint = op->template getParentOfType<func::FuncOp>();
- genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
- kInsertFuncNamePrefix, genInsertBody);
+ SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
+ SmallVector<Type> flatSpTensorTps = llvm::to_vector(
+ llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
+ params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+ params.push_back(crd);
+ params.push_back(value);
+ SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+ params, /*genCall=*/true);
+ SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
- rewriter.create<scf::YieldOp>(loc, desc.getFields());
+ rewriter.create<scf::YieldOp>(loc, insertRet);
+
rewriter.setInsertionPointAfter(loop);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
@@ -991,17 +962,18 @@ public:
LogicalResult
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
- SmallVector<Value> lcvs(adaptor.getLvlCoords());
- // Generate insertion.
- Value value = adaptor.getValue();
- auto insertPoint = op->template getParentOfType<func::FuncOp>();
- genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
- kInsertFuncNamePrefix, genInsertBody);
-
+ Location loc = op.getLoc();
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ TypeRange flatSpTensorTps = desc.getFields().getTypes();
+ SmallVector<Value> params = llvm::to_vector(desc.getFields());
+ params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+ params.push_back(adaptor.getValue());
+ SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+ params, /*genCall=*/true);
+ SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOp(op,
+ genTuple(rewriter, loc, op.getTensor().getType(), ret));
return success();
}
};