summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h66
1 files changed, 66 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 9e762892e864..caa549e45e41 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -72,6 +72,72 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
// Misc code generators and utilities.
//===----------------------------------------------------------------------===//
+/// A helper class to simplify lowering operations with/without function calls.
+template <class SubClass>
+class FuncCallOrInlineGenerator {
+public:
+ FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall)
+ : retTypes(retTypes), params(params), genCall(genCall) {}
+
+ // The main API invoked by clients, which abstracts away the details of
+ // creating function calls from clients.
+ SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) {
+ if (!genCall)
+ return genImplementation(retTypes, params, builder, loc);
+
+ // Looks up the function.
+ std::string funcName = getMangledFuncName();
+ ModuleOp module = getParentOpOf<ModuleOp>(builder);
+ MLIRContext *context = module.getContext();
+ auto result = SymbolRefAttr::get(context, funcName);
+ auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+ if (!func) {
+ // Create the function if not already exist.
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder));
+ func = builder.create<func::FuncOp>(
+ loc, funcName,
+ FunctionType::get(context, params.getTypes(), retTypes));
+ func.setPrivate();
+ // Set the insertion point to the body of the function.
+ Block *entryBB = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBB);
+ ValueRange args = entryBB->getArguments();
+ // Delegates to user to generate the actually implementation.
+ SmallVector<Value> result =
+ genImplementation(retTypes, args, builder, loc);
+ builder.create<func::ReturnOp>(loc, result);
+ }
+ // Returns the CallOp result.
+ func::CallOp call = builder.create<func::CallOp>(loc, func, params);
+ return call.getResults();
+ }
+
+private:
+ template <class OpTp>
+ OpTp getParentOpOf(OpBuilder &builder) {
+ return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>();
+ }
+
+ // CRTP: get the mangled function name (only called when genCall=true).
+ std::string getMangledFuncName() {
+ return static_cast<SubClass *>(this)->getMangledFuncName();
+ }
+
+ // CRTP: Client implementation.
+ SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params,
+ OpBuilder &builder, Location loc) {
+ return static_cast<SubClass *>(this)->genImplementation(retTypes, params,
+ builder, loc);
+ }
+
+private:
+ TypeRange retTypes; // The types of all returned results
+ ValueRange params; // The values of all input parameters
+ bool genCall; // Should the implemetantion be wrapped in a function
+};
+
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);