diff options
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 9e762892e864..caa549e45e41 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -72,6 +72,72 @@ StringRef primaryTypeFunctionSuffix(Type elemTp); // Misc code generators and utilities. //===----------------------------------------------------------------------===// +/// A helper class to simplify lowering operations with/without function calls. +template <class SubClass> +class FuncCallOrInlineGenerator { +public: + FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall) + : retTypes(retTypes), params(params), genCall(genCall) {} + + // The main API invoked by clients, which abstracts away the details of + // creating function calls from clients. + SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) { + if (!genCall) + return genImplementation(retTypes, params, builder, loc); + + // Looks up the function. + std::string funcName = getMangledFuncName(); + ModuleOp module = getParentOpOf<ModuleOp>(builder); + MLIRContext *context = module.getContext(); + auto result = SymbolRefAttr::get(context, funcName); + auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); + + if (!func) { + // Create the function if not already exist. + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder)); + func = builder.create<func::FuncOp>( + loc, funcName, + FunctionType::get(context, params.getTypes(), retTypes)); + func.setPrivate(); + // Set the insertion point to the body of the function. + Block *entryBB = func.addEntryBlock(); + builder.setInsertionPointToStart(entryBB); + ValueRange args = entryBB->getArguments(); + // Delegates to user to generate the actually implementation. + SmallVector<Value> result = + genImplementation(retTypes, args, builder, loc); + builder.create<func::ReturnOp>(loc, result); + } + // Returns the CallOp result. + func::CallOp call = builder.create<func::CallOp>(loc, func, params); + return call.getResults(); + } + +private: + template <class OpTp> + OpTp getParentOpOf(OpBuilder &builder) { + return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>(); + } + + // CRTP: get the mangled function name (only called when genCall=true). + std::string getMangledFuncName() { + return static_cast<SubClass *>(this)->getMangledFuncName(); + } + + // CRTP: Client implementation. + SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params, + OpBuilder &builder, Location loc) { + return static_cast<SubClass *>(this)->genImplementation(retTypes, params, + builder, loc); + } + +private: + TypeRange retTypes; // The types of all returned results + ValueRange params; // The values of all input parameters + bool genCall; // Should the implemetantion be wrapped in a function +}; + /// Add type casting between arith and index types when needed. Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); |