summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUday Bondhugula <uday@polymagelabs.com>2021-10-20 15:14:54 +0530
committerUday Bondhugula <uday@polymagelabs.com>2021-10-28 18:09:34 +0530
commit57b9b29649dacdd34ef8903d7bc5e6943188f480 (patch)
tree1357e1e7ad144cc38ae69dc72f2e351c38b767fb
parent349295fcf37ed1ff1ea98c18ea1b391741823916 (diff)
downloadllvm-57b9b29649dacdd34ef8903d7bc5e6943188f480.tar.gz
[MLIR][LLVM] Add llvm.mlir.global_ctors/dtors and translation support
Add llvm.mlir.global_ctors and global_dtors ops and their translation support to LLVM global_ctors/global_dtors global variables. Differential Revision: https://reviews.llvm.org/D112524
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td60
-rw-r--r--mlir/include/mlir/IR/OpBase.td5
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp62
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp26
-rw-r--r--mlir/test/Dialect/LLVMIR/global.mlir18
-rw-r--r--mlir/test/Dialect/LLVMIR/invalid.mlir30
-rw-r--r--mlir/test/Target/LLVMIR/llvmir.mlir18
7 files changed, 215 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 29c7446740a3..6bd64edf44c4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1153,6 +1153,66 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
let verifier = "return ::verify(*this);";
}
+def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let arguments = (ins FlatSymbolRefArrayAttr
+ : $ctors, I32ArrayAttr
+ : $priorities);
+ let summary = "LLVM dialect global_ctors.";
+ let description = [{
+ Specifies a list of constructor functions and priorities. The functions
+ referenced by this array will be called in ascending order of priority (i.e.
+ lowest first) when the module is loaded. The order of functions with the
+ same priority is not defined. This operation is translated to LLVM's
+ global_ctors global variable. The initializer functions are run at load
+ time. The `data` field present in LLVM's global_ctors variable is not
+ modeled here.
+
+ Examples:
+
+ ```mlir
+ llvm.mlir.global_ctors {@ctor}
+
+ llvm.func @ctor() {
+ ...
+ llvm.return
+ }
+ ```
+
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let assemblyFormat = "attr-dict";
+}
+
+def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let arguments = (ins
+ FlatSymbolRefArrayAttr:$dtors,
+ I32ArrayAttr:$priorities
+ );
+ let summary = "LLVM dialect global_dtors.";
+ let description = [{
+ Specifies a list of destructor functions and priorities. The functions
+ referenced by this array will be called in descending order of priority (i.e.
+ highest first) when the module is unloaded. The order of functions with the
+ same priority is not defined. This operation is translated to LLVM's
+ global_dtors global variable. The `data` field present in LLVM's
+ global_dtors variable is not modeled here.
+
+ Examples:
+
+ ```mlir
+ llvm.func @dtor() {
+ llvm.return
+ }
+ llvm.mlir.global_dtors {@dtor}
+ ```
+
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let assemblyFormat = "attr-dict";
+}
+
def LLVM_LLVMFuncOp : LLVM_Op<"func",
[AutomaticAllocationScope, IsolatedFromAbove, FunctionLike, Symbol]> {
let summary = "LLVM dialect function.";
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 978f95c01436..ec0d5355dcf1 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1641,6 +1641,11 @@ def SymbolRefArrayAttr :
let constBuilderCall = ?;
}
+def FlatSymbolRefArrayAttr :
+ TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
+ let constBuilderCall = ?;
+}
+
//===----------------------------------------------------------------------===//
// Derive attribute kinds
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 16a4eb65f6d2..47c0fe462cf2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -70,6 +70,22 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
}
+/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
+/// fully defined llvm.func.
+static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
+ Operation *op,
+ SymbolTableCollection &symbolTable) {
+ StringRef name = symbol.getValue();
+ auto func =
+ symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
+ if (!func)
+ return op->emitOpError("'")
+ << name << "' does not reference a valid LLVM function";
+ if (func.isExternal())
+ return op->emitOpError("'") << name << "' does not have a definition";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CmpOp.
//===----------------------------------------------------------------------===//
@@ -1625,6 +1641,48 @@ static LogicalResult verify(GlobalOp op) {
}
//===----------------------------------------------------------------------===//
+// LLVM::GlobalCtorsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ for (Attribute ctor : ctors()) {
+ if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
+ symbolTable)))
+ return failure();
+ }
+ return success();
+}
+
+static LogicalResult verify(GlobalCtorsOp op) {
+ if (op.ctors().size() != op.priorities().size())
+ return op.emitError(
+ "mismatch between the number of ctors and the number of priorities");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LLVM::GlobalDtorsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ for (Attribute dtor : dtors()) {
+ if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
+ symbolTable)))
+ return failure();
+ }
+ return success();
+}
+
+static LogicalResult verify(GlobalDtorsOp op) {
+ if (op.dtors().size() != op.priorities().size())
+ return op.emitError(
+ "mismatch between the number of dtors and the number of priorities");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ShuffleVectorOp.
//===----------------------------------------------------------------------===//
// Expects vector to be of wrapped LLVM vector type and position to be of
@@ -2353,7 +2411,7 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-static constexpr const FastmathFlags FastmathFlagsList[] = {
+static constexpr const FastmathFlags fastmathFlagsList[] = {
// clang-format off
FastmathFlags::nnan,
FastmathFlags::ninf,
@@ -2368,7 +2426,7 @@ static constexpr const FastmathFlags FastmathFlagsList[] = {
void FMFAttr::print(DialectAsmPrinter &printer) const {
printer << "fastmath<";
- auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
+ auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
return bitEnumContains(this->getFlags(), flag);
});
llvm::interleaveComma(flags, printer,
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 12bdbad6b025..c044e8c6bb6e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -42,6 +42,7 @@
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace mlir;
using namespace mlir::LLVM;
@@ -556,7 +557,7 @@ static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
}
/// Create named global variables that correspond to llvm.mlir.global
-/// definitions.
+/// definitions. Convert llvm.global_ctors and global_dtors ops.
LogicalResult ModuleTranslation::convertGlobals() {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
llvm::Type *type = convertType(op.getType());
@@ -625,6 +626,26 @@ LogicalResult ModuleTranslation::convertGlobals() {
}
}
+ // Convert llvm.mlir.global_ctors and dtors.
+ for (Operation &op : getModuleBody(mlirModule)) {
+ auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
+ auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
+ if (!ctorOp && !dtorOp)
+ continue;
+ auto range = ctorOp ? llvm::zip(ctorOp.ctors(), ctorOp.priorities())
+ : llvm::zip(dtorOp.dtors(), dtorOp.priorities());
+ auto appendGlobalFn =
+ ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
+ for (auto symbolAndPriority : range) {
+ llvm::Function *f = lookupFunction(
+ std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
+ appendGlobalFn(
+ *llvmModule.get(), f,
+ std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
+ /*Data=*/nullptr);
+ }
+ }
+
return success();
}
@@ -1028,7 +1049,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
// Convert other top-level operations if possible.
llvm::IRBuilder<> llvmBuilder(llvmContext);
for (Operation &o : getModuleBody(module).getOperations()) {
- if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::MetadataOp>(&o) &&
+ if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
+ LLVM::GlobalDtorsOp, LLVM::MetadataOp>(&o) &&
!o.hasTrait<OpTrait::IsTerminator>() &&
failed(translator.convertOperation(o, llvmBuilder))) {
return nullptr;
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index fd357a718c36..b831f3eb7c01 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -209,3 +209,21 @@ func @mismatch_addr_space() {
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm.ptr<i64, 4>
}
+
+// -----
+
+llvm.func @ctor() {
+ llvm.return
+}
+
+// CHECK: llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]}
+llvm.mlir.global_ctors { ctors = [@ctor], priorities = [0 : i32]}
+
+// -----
+
+llvm.func @dtor() {
+ llvm.return
+}
+
+// CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
+llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c80be74ce900..1caa3f415ee2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -5,6 +5,36 @@ llvm.mlir.global private @invalid_global_alignment(42 : i64) {alignment = 63} :
// -----
+llvm.func @ctor() {
+ llvm.return
+}
+
+// expected-error@+1{{mismatch between the number of ctors and the number of priorities}}
+llvm.mlir.global_ctors {ctors = [@ctor], priorities = []}
+
+// -----
+
+llvm.func @dtor() {
+ llvm.return
+}
+
+// expected-error@+1{{mismatch between the number of dtors and the number of priorities}}
+llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32, 32767 : i32]}
+
+// -----
+
+// expected-error@+1{{'ctor' does not reference a valid LLVM function}}
+llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]}
+
+// -----
+
+llvm.func @dtor()
+
+// expected-error@+1{{'dtor' does not have a definition}}
+llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
+
+// -----
+
// expected-error@+1{{expected llvm.noalias argument attribute to be a unit attribute}}
func @invalid_noalias(%arg0: i32 {llvm.noalias = 3}) {
"llvm.return"() : () -> ()
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index e16c579c12aa..7677e59c3cd9 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1384,6 +1384,24 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
// -----
+// CHECK: @llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }]
+llvm.mlir.global_ctors { ctors = [@foo], priorities = [0 : i32]}
+
+llvm.func @foo() {
+ llvm.return
+}
+
+// -----
+
+// CHECK: @llvm.global_dtors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }]
+llvm.mlir.global_dtors { dtors = [@foo], priorities = [0 : i32]}
+
+llvm.func @foo() {
+ llvm.return
+}
+
+// -----
+
// Check that branch weight attributes are exported properly as metadata.
llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// CHECK: !prof ![[NODE:[0-9]+]]