From ad469385ab37e7bd883f5e4b75459e480c0d4416 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 12 May 2023 20:33:49 +0000 Subject: [mlir][sparse] Add a helper class to help lowering operations with/without function calls Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D150477 --- .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 66 ++++++ .../Transforms/SparseTensorCodegen.cpp | 258 +++++++++------------ 2 files changed, 181 insertions(+), 143 deletions(-) (limited to 'mlir') diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 9e762892e864..caa549e45e41 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -72,6 +72,72 @@ StringRef primaryTypeFunctionSuffix(Type elemTp); // Misc code generators and utilities. //===----------------------------------------------------------------------===// +/// A helper class to simplify lowering operations with/without function calls. +template +class FuncCallOrInlineGenerator { +public: + FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall) + : retTypes(retTypes), params(params), genCall(genCall) {} + + // The main API invoked by clients, which abstracts away the details of + // creating function calls from clients. + SmallVector genCallOrInline(OpBuilder &builder, Location loc) { + if (!genCall) + return genImplementation(retTypes, params, builder, loc); + + // Looks up the function. + std::string funcName = getMangledFuncName(); + ModuleOp module = getParentOpOf(builder); + MLIRContext *context = module.getContext(); + auto result = SymbolRefAttr::get(context, funcName); + auto func = module.lookupSymbol(result.getAttr()); + + if (!func) { + // Create the function if not already exist. 
+ OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPoint(getParentOpOf(builder)); + func = builder.create( + loc, funcName, + FunctionType::get(context, params.getTypes(), retTypes)); + func.setPrivate(); + // Set the insertion point to the body of the function. + Block *entryBB = func.addEntryBlock(); + builder.setInsertionPointToStart(entryBB); + ValueRange args = entryBB->getArguments(); + // Delegates to the user to generate the actual implementation. + SmallVector result = + genImplementation(retTypes, args, builder, loc); + builder.create(loc, result); + } + // Returns the CallOp result. + func::CallOp call = builder.create(loc, func, params); + return call.getResults(); + } + +private: + template + OpTp getParentOpOf(OpBuilder &builder) { + return builder.getInsertionBlock()->getParent()->getParentOfType(); + } + + // CRTP: get the mangled function name (only called when genCall=true). + std::string getMangledFuncName() { + return static_cast(this)->getMangledFuncName(); + } + + // CRTP: Client implementation. + SmallVector genImplementation(TypeRange retTypes, ValueRange params, + OpBuilder &builder, Location loc) { + return static_cast(this)->genImplementation(retTypes, params, + builder, loc); + } + +private: + TypeRange retTypes; // The types of all returned results + ValueRange params; // The values of all input parameters + bool genCall; // Should the implementation be wrapped in a function +}; + /// Add type casting between arith and index types when needed. 
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 88f79bf3e8d4..e729f725689d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -42,8 +42,6 @@ namespace { using FuncGeneratorType = function_ref; -static constexpr const char kInsertFuncNamePrefix[] = "_insert_"; - //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// @@ -396,134 +394,102 @@ static Value genCompressed(OpBuilder &builder, Location loc, return ifOp2.getResult(o); } -/// Generates code along an insertion path without the need for a "cursor". -/// This current insertion strategy comes at the expense of some testing -/// overhead for each insertion. The strategy will be optimized later for -/// common insertion patterns. The current insertion strategy also assumes -/// insertions occur in "a reasonable order" that enables building the -/// storage scheme in an appending/inserting kind of fashion (i.e. no -/// in-between insertions that need data movement). The implementation -/// relies on CSE/DCE to clean up all bookkeeping that is not needed. -/// -/// TODO: better unord/not-unique; also generalize, optimize, specialize! -/// -static void genInsertBody(OpBuilder &builder, ModuleOp module, - func::FuncOp func, RankedTensorType rtp) { - const OpBuilder::InsertionGuard insertionGuard(builder); - Block *const entryBlock = func.addEntryBlock(); - builder.setInsertionPointToStart(entryBlock); - const ValueRange args = entryBlock->getArguments(); - const Location loc = func.getLoc(); - const SparseTensorType stt(rtp); - const Level lvlRank = stt.getLvlRank(); - - // Extract fields and coordinates from args. 
- SmallVector fields = llvm::to_vector(args.drop_back(lvlRank + 1)); - MutSparseTensorDescriptor desc(rtp, fields); - const SmallVector coords = - llvm::to_vector(args.take_back(lvlRank + 1).drop_back()); - Value value = args.back(); - Value parentPos = constantZero(builder, loc, builder.getIndexType()); - // Generate code for every level. - for (Level l = 0; l < lvlRank; l++) { - const auto dlt = stt.getLvlType(l); - if (isCompressedDLT(dlt)) { - // Create: - // if (!present) { - // coordinates[l].push_back(coords[l]) - // - // } - // positions[l] = coordinates.size() - 1 - // - parentPos = - genCompressed(builder, loc, desc, coords, value, parentPos, l); - } else if (isSingletonDLT(dlt)) { - // Create: - // coordinates[l].push_back(coords[l]) - // positions[l] = positions[l-1] - // - createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l, - coords[l]); - } else { - assert(isDenseDLT(dlt)); - // Construct the new position as: - // positions[l] = size * positions[l-1] + coords[l] - // - Value size = sizeFromTensorAtLvl(builder, loc, desc, l); - Value mult = builder.create(loc, size, parentPos); - parentPos = builder.create(loc, mult, coords[l]); +/// Helper class to help lowering sparse_tensor.insert operation. +class SparseInsertGenerator + : public FuncCallOrInlineGenerator { +public: + SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params, + bool genCall) + : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){}; + + /// Generates code along an insertion path without the need for a "cursor". + /// This current insertion strategy comes at the expense of some testing + /// overhead for each insertion. The strategy will be optimized later for + /// common insertion patterns. The current insertion strategy also assumes + /// insertions occur in "a reasonable order" that enables building the + /// storage scheme in an appending/inserting kind of fashion (i.e. no + /// in-between insertions that need data movement). 
The implementation + /// relies on CSE/DCE to clean up all bookkeeping that is not needed. + /// + /// TODO: better unord/not-unique; also generalize, optimize, specialize! + SmallVector genImplementation(TypeRange retTypes, ValueRange args, + OpBuilder &builder, Location loc) { + const SparseTensorType stt(rtp.cast()); + const Level lvlRank = stt.getLvlRank(); + // Extract fields and coordinates from args. + SmallVector fields = llvm::to_vector(args.drop_back(lvlRank + 1)); + MutSparseTensorDescriptor desc(stt, fields); + const SmallVector coords = + llvm::to_vector(args.take_back(lvlRank + 1).drop_back()); + Value value = args.back(); + Value parentPos = constantZero(builder, loc, builder.getIndexType()); + // Generate code for every level. + for (Level l = 0; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { + // Create: + // if (!present) { + // coordinates[l].push_back(coords[l]) + // + // } + // positions[l] = coordinates.size() - 1 + // + parentPos = + genCompressed(builder, loc, desc, coords, value, parentPos, l); + } else if (isSingletonDLT(dlt)) { + // Create: + // coordinates[l].push_back(coords[l]) + // positions[l] = positions[l-1] + // + createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l, + coords[l]); + } else { + assert(isDenseDLT(dlt)); + // Construct the new position as: + // positions[l] = size * positions[l-1] + coords[l] + // + Value size = sizeFromTensorAtLvl(builder, loc, desc, l); + Value mult = builder.create(loc, size, parentPos); + parentPos = builder.create(loc, mult, coords[l]); + } } + // Reached the actual value append/insert. + if (!stt.isDenseLvl(lvlRank - 1)) + createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, + std::nullopt, value); + else + genStore(builder, loc, value, desc.getValMemRef(), parentPos); + return fields; } - // Reached the actual value append/insert. 
- if (!stt.isDenseLvl(lvlRank - 1)) - createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, - std::nullopt, value); - else - genStore(builder, loc, value, desc.getValMemRef(), parentPos); - builder.create(loc, fields); -} -/// Generates a call to a function to perform an insertion operation. If the -/// function doesn't exist yet, call `createFunc` to generate the function. -static void genInsertionCallHelper(OpBuilder &builder, - MutSparseTensorDescriptor desc, - SmallVectorImpl &lcvs, Value value, - func::FuncOp insertPoint, - StringRef namePrefix, - FuncGeneratorType createFunc) { - // The mangled name of the function has this format: - // ______ - const SparseTensorType stt(desc.getRankedTensorType()); - SmallString<32> nameBuffer; - llvm::raw_svector_ostream nameOstream(nameBuffer); - nameOstream << namePrefix; - const Level lvlRank = stt.getLvlRank(); - assert(lcvs.size() == static_cast(lvlRank)); - for (Level l = 0; l < lvlRank; l++) - nameOstream << toMLIRString(stt.getLvlType(l)) << "_"; - // Static dim sizes are used in the generated code while dynamic sizes are - // loaded from the dimSizes buffer. This is the reason for adding the shape - // to the function name. - for (const auto sh : stt.getDimShape()) - nameOstream << sh << "_"; - // Permutation information is also used in generating insertion. - if (!stt.isIdentity()) - nameOstream << stt.getDimToLvlMap() << "_"; - nameOstream << stt.getElementType() << "_"; - nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth(); - - // Look up the function. - ModuleOp module = insertPoint->getParentOfType(); - MLIRContext *context = module.getContext(); - auto result = SymbolRefAttr::get(context, nameOstream.str()); - auto func = module.lookupSymbol(result.getAttr()); - - // Construct operands: fields, coords, and value. 
- SmallVector operands = llvm::to_vector(desc.getFields()); - operands.append(lcvs); - operands.push_back(value); - Location loc = insertPoint.getLoc(); - - if (!func) { - // Create the function. - OpBuilder::InsertionGuard insertionGuard(builder); - builder.setInsertionPoint(insertPoint); - - func = builder.create( - loc, nameOstream.str(), - FunctionType::get(context, ValueRange(operands).getTypes(), - ValueRange(desc.getFields()).getTypes())); - func.setPrivate(); - createFunc(builder, module, func, stt); + std::string getMangledFuncName() { + // The mangled name of the function has this format: + // ______ + constexpr const char kInsertFuncNamePrefix[] = "_insert_"; + const SparseTensorType stt(rtp.cast()); + + SmallString<32> nameBuffer; + llvm::raw_svector_ostream nameOstream(nameBuffer); + nameOstream << kInsertFuncNamePrefix; + const Level lvlRank = stt.getLvlRank(); + for (Level l = 0; l < lvlRank; l++) + nameOstream << toMLIRString(stt.getLvlType(l)) << "_"; + // Static dim sizes are used in the generated code while dynamic sizes are + // loaded from the dimSizes buffer. This is the reason for adding the shape + // to the function name. + for (const auto sh : stt.getDimShape()) + nameOstream << sh << "_"; + // Permutation information is also used in generating insertion. + if (!stt.isIdentity()) + nameOstream << stt.getDimToLvlMap() << "_"; + nameOstream << stt.getElementType() << "_"; + nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth(); + return nameOstream.str().str(); } - // Generate a call to perform the insertion and update `fields` with values - // returned from the call. - func::CallOp call = builder.create(loc, func, operands); - for (size_t i = 0, e = desc.getNumFields(); i < e; i++) { - desc.getFields()[i] = call.getResult(i); - } -} +private: + TensorType rtp; +}; /// Generations insertion finalization code. 
static void genEndInsert(OpBuilder &builder, Location loc, @@ -936,8 +902,7 @@ public: Value count = adaptor.getCount(); const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); - // Prepare level-coords. - SmallVector lcvs(adaptor.getLvlCoords()); + // If the innermost level is ordered, we need to sort the coordinates // in the "added" array prior to applying the compression. if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) @@ -960,16 +925,22 @@ public: // } scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); Value i = loop.getInductionVar(); + Value crd = genLoad(rewriter, loc, added, i); Value value = genLoad(rewriter, loc, values, crd); - lcvs.push_back(crd); - // TODO: faster for subsequent insertions? - auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint, - kInsertFuncNamePrefix, genInsertBody); + SmallVector params(desc.getFields().begin(), desc.getFields().end()); + SmallVector flatSpTensorTps = llvm::to_vector( + llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); + params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); + params.push_back(crd); + params.push_back(value); + SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, + params, /*genCall=*/true); + SmallVector insertRet = insertGen.genCallOrInline(rewriter, loc); genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd); genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd); - rewriter.create(loc, desc.getFields()); + rewriter.create(loc, insertRet); + rewriter.setInsertionPointAfter(loop); Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. 
@@ -991,17 +962,18 @@ public: LogicalResult matchAndRewrite(InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); - SmallVector lcvs(adaptor.getLvlCoords()); - // Generate insertion. - Value value = adaptor.getValue(); - auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint, - kInsertFuncNamePrefix, genInsertBody); - + Location loc = op.getLoc(); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + TypeRange flatSpTensorTps = desc.getFields().getTypes(); + SmallVector params = llvm::to_vector(desc.getFields()); + params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); + params.push_back(adaptor.getValue()); + SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, + params, /*genCall=*/true); + SmallVector ret = insertGen.genCallOrInline(rewriter, loc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); + rewriter.replaceOp(op, + genTuple(rewriter, loc, op.getTensor().getType(), ret)); return success(); } }; -- cgit v1.2.1