summaryrefslogtreecommitdiff
path: root/mlir/tools
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2023-02-26 10:46:01 -0500
committerMehdi Amini <joker.eph@gmail.com>2023-05-01 23:16:34 -0700
commit5e118f933b6590cecd7f1afb30845a1594bc4a5d (patch)
treedeb16795854810d0f1ee2b8a40dbeeee6a1647c0 /mlir/tools
parent8f966cedea594d9a91e585e88a80a42c04049e6c (diff)
downloadllvm-5e118f933b6590cecd7f1afb30845a1594bc4a5d.tar.gz
Introduce MLIR Op Properties
This new features enabled to dedicate custom storage inline within operations. This storage can be used as an alternative to attributes to store data that is specific to an operation. Attribute can also be stored inside the properties storage if desired, but any kind of data can be present as well. This offers a way to store and mutate data without uniquing in the Context like Attribute. See the OpPropertiesTest.cpp for an example where a struct with a std::vector<> is attached to an operation and mutated in-place: struct TestProperties { int a = -1; float b = -1.; std::vector<int64_t> array = {-33}; }; More complex scheme (including reference-counting) are also possible. The only constraint to enable storing a C++ object as "properties" on an operation is to implement three functions: - convert from the candidate object to an Attribute - convert from the Attribute to the candidate object - hash the object Optional the parsing and printing can also be customized with 2 extra functions. A new options is introduced to ODS to allow dialects to specify: let usePropertiesForAttributes = 1; When set to true, the inherent attributes for all the ops in this dialect will be using properties instead of being stored alongside discardable attributes. The TestDialect showcases this feature. Another change is that we introduce new APIs on the Operation class to access separately the inherent attributes from the discardable ones. We envision deprecating and removing the `getAttr()`, `getAttrsDictionary()`, and other similar method which don't make the distinction explicit, leading to an entirely separate namespace for discardable attributes. Recommit d572cd1b067f after fixing python bindings build. Differential Revision: https://reviews.llvm.org/D141742
Diffstat (limited to 'mlir/tools')
-rw-r--r--mlir/tools/mlir-tblgen/FormatGen.cpp1
-rw-r--r--mlir/tools/mlir-tblgen/FormatGen.h2
-rw-r--r--mlir/tools/mlir-tblgen/OpClass.cpp3
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp744
-rw-r--r--mlir/tools/mlir-tblgen/OpFormatGen.cpp259
5 files changed, 840 insertions, 169 deletions
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 7d2e03ecfe27..b4f71fb45b37 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -177,6 +177,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
StringSwitch<FormatToken::Kind>(str)
.Case("attr-dict", FormatToken::kw_attr_dict)
.Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+ .Case("prop-dict", FormatToken::kw_prop_dict)
.Case("custom", FormatToken::kw_custom)
.Case("functional-type", FormatToken::kw_functional_type)
.Case("oilist", FormatToken::kw_oilist)
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index e5fd04a24b2f..da30e35f1353 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -60,6 +60,7 @@ public:
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
+ kw_prop_dict,
kw_custom,
kw_functional_type,
kw_oilist,
@@ -287,6 +288,7 @@ public:
/// These are the kinds of directives.
enum Kind {
AttrDict,
+ PropDict,
Custom,
FunctionalType,
OIList,
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 40b688f2b96c..698569c790e9 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -37,5 +37,6 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
void OpClass::finalize() {
Class::finalize();
declare<VisibilityDeclaration>(Visibility::Public);
- declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
+ declare<ExtraClassDeclaration>(extraClassDeclaration.str(),
+ extraClassDefinition);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index e3e7ae087085..dc257fd97b0e 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -14,20 +14,25 @@
#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -43,6 +48,10 @@ static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
+static const char *const propertyStorage = "propStorage";
+static const char *const propertyValue = "propValue";
+static const char *const propertyAttr = "propAttr";
+static const char *const propertyDiag = "propDiag";
/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
@@ -103,7 +112,7 @@ static const char *const attrSizedSegmentValueRangeCalcCode = R"(
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
- assert(odsAttrs && "missing segment size attribute for op");
+ assert({0} && "missing segment size attribute for op");
auto sizeAttr = {0}.cast<::mlir::DenseI32ArrayAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation.
@@ -260,6 +269,10 @@ public:
assert(attrMetadata.count(attrName) && "expected attribute metadata");
return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
+ if (hasProperties()) {
+ assert(!isNamed);
+ return os << "getProperties()." << attrName;
+ }
return os << formatv(subrangeGetAttr, getAttrName(attrName),
attr.lowerBound, attr.upperBound, getAttrRange(),
isNamed ? "Named" : "");
@@ -324,6 +337,19 @@ public:
return attrMetadata;
}
+ /// Returns whether to emit a `Properties` struct for this operation or not.
+ bool hasProperties() const {
+ if (!op.getProperties().empty())
+ return true;
+ if (!op.getDialect().usePropertiesForAttributes())
+ return false;
+ return llvm::any_of(getAttrMetadata(),
+ [](const std::pair<StringRef, AttributeMetadata> &it) {
+ return !it.second.constraint ||
+ !it.second.constraint->isDerivedAttr();
+ });
+ }
+
private:
// Compute the attribute metadata.
void computeAttrMetadata();
@@ -418,6 +444,9 @@ private:
// Generates the `getOperationName` method for this op.
void genOpNameGetter();
+ // Generates code to manage the properties, if any!
+ void genPropertiesSupport();
+
// Generates getters for the attributes.
void genAttrGetters();
@@ -642,6 +671,20 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
}
}
+// Return true if a verifier can be emitted for the attribute: it is not a
+// derived attribute, it has a predicate, its condition is not empty, and, for
+// adaptors, the condition does not reference the op.
+static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
+ if (attr.isDerivedAttr())
+ return false;
+ Pred pred = attr.getPredicate();
+ if (pred.isNull())
+ return false;
+ std::string condition = pred.getCondition();
+ return !condition.empty() &&
+ (!StringRef(condition).contains("$_op") || isEmittingForOp);
+}
+
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
@@ -654,9 +697,11 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
-static void genAttributeVerifier(
- const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
+static void
+genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
+ MethodBody &body,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool useProperties) {
if (emitHelper.getAttrMetadata().empty())
return;
@@ -691,7 +736,8 @@ static void genAttributeVerifier(
// {0}: Code to get the name of the attribute.
// {1}: The emit error prefix.
// {2}: The name of the attribute.
- const char *const findRequiredAttr = R"(while (true) {{
+ const char *const findRequiredAttr = R"(
+while (true) {{
if (namedAttrIt == namedAttrRange.end())
return {1}"requires attribute '{2}'");
if (namedAttrIt->getName() == {0}) {{
@@ -714,20 +760,6 @@ static void genAttributeVerifier(
break;
})";
- // Return true if a verifier can be emitted for the attribute: it is not a
- // derived attribute, it has a predicate, its condition is not empty, and, for
- // adaptors, the condition does not reference the op.
- const auto canEmitVerifier = [&](Attribute attr) {
- if (attr.isDerivedAttr())
- return false;
- Pred pred = attr.getPredicate();
- if (pred.isNull())
- return false;
- std::string condition = pred.getCondition();
- return !condition.empty() && (!StringRef(condition).contains("$_op") ||
- emitHelper.isEmittingForOp());
- };
-
// Emit the verifier for the attribute.
const auto emitVerifier = [&](Attribute attr, StringRef attrName,
StringRef varName) {
@@ -750,58 +782,74 @@ static void genAttributeVerifier(
return (tblgenNamePrefix + attrName).str();
};
- body.indent() << formatv("auto namedAttrRange = {0};\n",
- emitHelper.getAttrRange());
- body << "auto namedAttrIt = namedAttrRange.begin();\n";
-
- // Iterate over the attributes in sorted order. Keep track of the optional
- // attributes that may be encountered along the way.
- SmallVector<const AttributeMetadata *> optionalAttrs;
- for (const std::pair<StringRef, AttributeMetadata> &it :
- emitHelper.getAttrMetadata()) {
- const AttributeMetadata &metadata = it.second;
- if (!metadata.isRequired) {
- optionalAttrs.push_back(&metadata);
- continue;
+ body.indent();
+ if (useProperties) {
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ body << formatv(
+ "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
+ it.first);
+ const AttributeMetadata &metadata = it.second;
+ if (metadata.isRequired)
+ body << formatv(
+ "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
+ it.first, emitHelper.emitErrorPrefix());
}
+ } else {
+ body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange());
+ body << "auto namedAttrIt = namedAttrRange.begin();\n";
+
+ // Iterate over the attributes in sorted order. Keep track of the optional
+ // attributes that may be encountered along the way.
+ SmallVector<const AttributeMetadata *> optionalAttrs;
+
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ const AttributeMetadata &metadata = it.second;
+ if (!metadata.isRequired) {
+ optionalAttrs.push_back(&metadata);
+ continue;
+ }
- body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv("::mlir::Attribute {0};\n",
- getVarName(optional->attrName));
- }
- body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
- emitHelper.emitErrorPrefix(), it.first);
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv(checkOptionalAttr,
- emitHelper.getAttrName(optional->attrName),
- optional->attrName);
- }
- body << "\n ++namedAttrIt;\n}\n";
- optionalAttrs.clear();
- }
- // Get trailing optional attributes.
- if (!optionalAttrs.empty()) {
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv("::mlir::Attribute {0};\n",
- getVarName(optional->attrName));
+ body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
+ emitHelper.emitErrorPrefix(), it.first);
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
+ }
+ body << "\n ++namedAttrIt;\n}\n";
+ optionalAttrs.clear();
}
- body << checkTrailingAttrs;
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv(checkOptionalAttr,
- emitHelper.getAttrName(optional->attrName),
- optional->attrName);
+ // Get trailing optional attributes.
+ if (!optionalAttrs.empty()) {
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << checkTrailingAttrs;
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
+ }
+ body << "\n ++namedAttrIt;\n}\n";
}
- body << "\n ++namedAttrIt;\n}\n";
}
body.unindent();
- // Emit the checks for segment attributes first so that the other constraints
- // can call operand and result getters.
+ // Emit the checks for segment attributes first so that the other
+ // constraints can call operand and result getters.
genNativeTraitAttrVerifier(body, emitHelper);
+ bool isEmittingForOp = emitHelper.isEmittingForOp();
for (const auto &namedAttr : emitHelper.getOp().getAttributes())
- if (canEmitVerifier(namedAttr.attr))
+ if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp))
emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
@@ -834,6 +882,7 @@ OpEmitter::OpEmitter(const Operator &op,
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
+ genPropertiesSupport();
genAttrGetters();
genAttrSetters();
genOptionalAttrRemovers();
@@ -989,6 +1038,274 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx,
<< ";\n";
}
+void OpEmitter::genPropertiesSupport() {
+ if (!emitHelper.hasProperties())
+ return;
+ using ConstArgument =
+ llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
+
+ SmallVector<ConstArgument> attrOrProperties;
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
+ attrOrProperties.push_back(&it.second);
+ }
+ for (const NamedProperty &prop : op.getProperties())
+ attrOrProperties.push_back(&prop);
+ if (attrOrProperties.empty())
+ return;
+ auto &setPropMethod =
+ opClass
+ .addStaticMethod(
+ "::mlir::LogicalResult", "setPropertiesFromAttr",
+ MethodParameter("Properties &", "prop"),
+ MethodParameter("::mlir::Attribute", "attr"),
+ MethodParameter("::mlir::InFlightDiagnostic *", "diag"))
+ ->body();
+ auto &getPropMethod =
+ opClass
+ .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
+ MethodParameter("::mlir::MLIRContext *", "ctx"),
+ MethodParameter("const Properties &", "prop"))
+ ->body();
+ auto &hashMethod =
+ opClass
+ .addStaticMethod("llvm::hash_code", "computePropertiesHash",
+ MethodParameter("const Properties &", "prop"))
+ ->body();
+ auto &getInherentAttrMethod =
+ opClass
+ .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
+ MethodParameter("const Properties &", "prop"),
+ MethodParameter("llvm::StringRef", "name"))
+ ->body();
+ auto &setInherentAttrMethod =
+ opClass
+ .addStaticMethod("void", "setInherentAttr",
+ MethodParameter("Properties &", "prop"),
+ MethodParameter("llvm::StringRef", "name"),
+ MethodParameter("mlir::Attribute", "value"))
+ ->body();
+ auto &populateInherentAttrsMethod =
+ opClass
+ .addStaticMethod("void", "populateInherentAttrs",
+ MethodParameter("const Properties &", "prop"),
+ MethodParameter("::mlir::NamedAttrList &", "attrs"))
+ ->body();
+ auto &verifyInherentAttrsMethod =
+ opClass
+ .addStaticMethod(
+ "::mlir::LogicalResult", "verifyInherentAttrs",
+ MethodParameter("::mlir::OperationName", "opName"),
+ MethodParameter("::mlir::NamedAttrList &", "attrs"),
+ MethodParameter(
+ "llvm::function_ref<::mlir::InFlightDiagnostic()>",
+ "getDiag"))
+ ->body();
+
+ opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties");
+
+ // Convert the property to the attribute form.
+
+ setPropMethod << R"decl(
+ ::mlir::DictionaryAttr dict = dyn_cast<::mlir::DictionaryAttr>(attr);
+ if (!dict) {
+ if (diag)
+ *diag << "expected DictionaryAttr to set properties";
+ return failure();
+ }
+ )decl";
+ // TODO: properties might be optional as well.
+ const char *propFromAttrFmt = R"decl(;
+ {{
+ auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
+ ::mlir::InFlightDiagnostic *propDiag) {{
+ {0};
+ };
+ auto attr = dict.get("{1}");
+ if (!attr) {{
+ if (diag)
+ *diag << "expected key entry for {1} in DictionaryAttr to set "
+ "Properties.";
+ return failure();
+ }
+ if (failed(setFromAttr(prop.{1}, attr, diag))) return ::mlir::failure();
+ }
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ setPropMethod << formatv(propFromAttrFmt,
+ tgfmt(prop.getConvertFromAttributeCall(),
+ &fctx.addSubst("_attr", propertyAttr)
+ .addSubst("_storage", propertyStorage)
+ .addSubst("_diag", propertyDiag)),
+ name);
+ } else {
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ setPropMethod << formatv(R"decl(
+ {{
+ auto &propStorage = prop.{0};
+ auto attr = dict.get("{0}");
+ if (attr || /*isRequired=*/{1}) {{
+ if (!attr) {{
+ if (diag)
+ *diag << "expected key entry for {0} in DictionaryAttr to set "
+ "Properties.";
+ return failure();
+ }
+ auto convertedAttr = dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
+ if (convertedAttr) {{
+ propStorage = convertedAttr;
+ } else {{
+ if (diag)
+ *diag << "Invalid attribute `{0}` in property conversion: " << attr;
+ return failure();
+ }
+ }
+ }
+)decl",
+ name, namedAttr->isRequired);
+ }
+ }
+ setPropMethod << " return ::mlir::success();\n";
+
+ // Convert the attribute form to the property.
+
+ getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
+ << " ::mlir::Builder odsBuilder{ctx};\n";
+ const char *propToAttrFmt = R"decl(
+ {
+ const auto &propStorage = prop.{0};
+ attrs.push_back(odsBuilder.getNamedAttr("{0}",
+ {1}));
+ }
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ getPropMethod << formatv(
+ propToAttrFmt, name,
+ tgfmt(prop.getConvertToAttributeCall(),
+ &fctx.addSubst("_ctxt", "ctx")
+ .addSubst("_storage", propertyStorage)));
+ continue;
+ }
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ getPropMethod << formatv(R"decl(
+ {{
+ const auto &propStorage = prop.{0};
+ if (propStorage)
+ attrs.push_back(odsBuilder.getNamedAttr("{0}",
+ propStorage));
+ }
+)decl",
+ name);
+ }
+ getPropMethod << R"decl(
+ if (!attrs.empty())
+ return odsBuilder.getDictionaryAttr(attrs);
+ return {};
+)decl";
+
+ // Hashing for the property
+
+ const char *propHashFmt = R"decl(
+ auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
+ return {1};
+ };
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ hashMethod << formatv(propHashFmt, name,
+ tgfmt(prop.getHashPropertyCall(),
+ &fctx.addSubst("_storage", propertyStorage)));
+ }
+ }
+ hashMethod << " return llvm::hash_combine(";
+ llvm::interleaveComma(
+ attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ hashMethod << "\n hash_" << namedProperty->name << "(prop."
+ << namedProperty->name << ")";
+ return;
+ }
+ const auto *namedAttr =
+ attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ hashMethod << "\n llvm::hash_value(prop." << name
+ << ".getAsOpaquePointer())";
+ });
+ hashMethod << ");\n";
+
+ const char *getInherentAttrMethodFmt = R"decl(
+ if (name == "{0}")
+ return prop.{0};
+)decl";
+ const char *setInherentAttrMethodFmt = R"decl(
+ if (name == "{0}") {{
+ prop.{0} = dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
+ return;
+ }
+)decl";
+ const char *populateInherentAttrsMethodFmt = R"decl(
+ if (prop.{0}) attrs.append("{0}", prop.{0});
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedAttr =
+ attrOrProp.dyn_cast<const AttributeMetadata *>()) {
+ StringRef name = namedAttr->attrName;
+ getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
+ setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
+ populateInherentAttrsMethod
+ << formatv(populateInherentAttrsMethodFmt, name);
+ continue;
+ }
+ }
+ getInherentAttrMethod << " return std::nullopt;\n";
+
+ // Emit the verifiers method for backward compatibility with the generic
+ // syntax. This method verifies the constraint on the properties attributes
+ // before they are set, since dyn_cast<> will silently omit failures.
+ for (const auto &attrOrProp : attrOrProperties) {
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ if (!namedAttr || !namedAttr->constraint)
+ continue;
+ Attribute attr = *namedAttr->constraint;
+ std::optional<StringRef> constraintFn =
+ staticVerifierEmitter.getAttrConstraintFn(attr);
+ if (!constraintFn)
+ continue;
+ if (canEmitAttrVerifier(attr,
+ /*isEmittingForOp=*/false)) {
+ std::string name = op.getGetterName(namedAttr->attrName);
+ verifyInherentAttrsMethod
+ << formatv(R"(
+ {{
+ ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
+ if (attr && ::mlir::failed({1}(attr, "{2}", getDiag)))
+ return ::mlir::failure();
+ }
+)",
+ name, constraintFn, namedAttr->attrName);
+ }
+ }
+ verifyInherentAttrsMethod << " return ::mlir::success();";
+}
+
void OpEmitter::genAttrGetters() {
FmtContext fctx;
fctx.withBuilder("::mlir::Builder((*this)->getContext())");
@@ -999,9 +1316,9 @@ void OpEmitter::genAttrGetters() {
method->body() << " " << attr.getDerivedCodeBody() << "\n";
};
- // Generate named accessor with Attribute return type. This is a wrapper class
- // that allows referring to the attributes via accessors instead of having to
- // use the string interface for better compile time verification.
+ // Generate named accessor with Attribute return type. This is a wrapper
+ // class that allows referring to the attributes via accessors instead of
+ // having to use the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
Attribute attr) {
auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
@@ -1086,7 +1403,8 @@ void OpEmitter::genAttrGetters() {
body << " {" << name << "AttrName(),\n"
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
.withBuilder("odsBuilder")
- .addSubst("_ctxt", "ctx"))
+ .addSubst("_ctxt", "ctx")
+ .addSubst("_storage", "ctx"))
<< "}";
},
",\n");
@@ -1175,19 +1493,29 @@ void OpEmitter::genAttrSetters() {
void OpEmitter::genOptionalAttrRemovers() {
// Generate methods for removing optional attributes, instead of having to
// use the string interface. Enables better compile time verification.
- auto emitRemoveAttr = [&](StringRef name) {
+ auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
auto upperInitial = name.take_front().upper();
auto *method = opClass.addMethod("::mlir::Attribute",
op.getRemoverName(name) + "Attr");
if (!method)
return;
- method->body() << formatv(" return (*this)->removeAttr({0}AttrName());",
+ if (useProperties) {
+ method->body() << formatv(R"(
+ auto &attr = getProperties().{0};
+ attr = {{};
+ return attr;
+)",
+ name);
+ return;
+ }
+ method->body() << formatv("return (*this)->removeAttr({0}AttrName());",
op.getGetterName(name));
};
for (const NamedAttribute &namedAttr : op.getAttributes())
if (namedAttr.attr.isOptional())
- emitRemoveAttr(namedAttr.name);
+ emitRemoveAttr(namedAttr.name,
+ op.getDialect().usePropertiesForAttributes());
}
// Generates the code to compute the start and end index of an operand or result
@@ -1417,9 +1745,15 @@ void OpEmitter::genNamedOperandSetters() {
"::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands) {
- body << formatv(
- ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
- emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ if (emitHelper.hasProperties())
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+ "{getOperandSegmentSizesAttrName(), getProperties().{1}})",
+ i, operandSegmentAttrName);
+ else
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+ emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
}
body << ");\n";
@@ -1623,6 +1957,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, {1}.operands,
{1}.attributes.getDictionary({1}.getContext()),
+ {1}.getRawProperties(),
{1}.regions, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
@@ -1645,11 +1980,17 @@ void OpEmitter::genSeparateArgParamBuilder() {
// Automatically create the 'result_segment_sizes' attribute using
// the length of the type ranges.
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- std::string getterName = op.getGetterName(resultSegmentAttrName);
- body << " " << builderOpState << ".addAttribute(" << getterName
- << "AttrName(" << builderOpState << ".name), "
- << "odsBuilder.getDenseI32ArrayAttr({";
-
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " (" << builderOpState
+ << ".getOrAddProperties<Properties>()." << resultSegmentAttrName
+ << " = \n"
+ " odsBuilder.getDenseI32ArrayAttr({";
+ } else {
+ std::string getterName = op.getGetterName(resultSegmentAttrName);
+ body << " " << builderOpState << ".addAttribute(" << getterName
+ << "AttrName(" << builderOpState << ".name), "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ }
interleaveComma(
llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
const NamedTypeConstraint &result = op.getResult(i);
@@ -1748,6 +2089,32 @@ void OpEmitter::genPopulateDefaultAttributes() {
}))
return;
+ if (op.getDialect().usePropertiesForAttributes()) {
+ SmallVector<MethodParameter> paramList;
+ paramList.emplace_back("::mlir::OperationName", "opName");
+ paramList.emplace_back("Properties &", "properties");
+ auto *m =
+ opClass.addStaticMethod("void", "populateDefaultProperties", paramList);
+ ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
+ auto &body = m->body();
+ body.indent();
+ body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
+ for (const NamedAttribute &namedAttr : op.getAttributes()) {
+ auto &attr = namedAttr.attr;
+ if (!attr.hasDefaultValue() || attr.isOptional())
+ continue;
+ StringRef name = namedAttr.name;
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+ body << "if (!properties." << name << ")\n"
+ << " properties." << name << " = "
+ << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)))
+ << ";\n";
+ }
+ return;
+ }
+
SmallVector<MethodParameter> paramList;
paramList.emplace_back("const ::mlir::OperationName &", "opName");
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
@@ -1830,6 +2197,7 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, operands,
{1}.attributes.getDictionary({1}.getContext()),
+ {1}.getRawProperties(),
{1}.regions, inferredReturnTypes))) {{)",
opClass.getClassName(), builderOpState);
if (numVariadicResults == 0 || numNonVariadicResults != 0)
@@ -2147,6 +2515,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
operand->isOptional());
continue;
}
+ if (const auto *operand = arg.dyn_cast<NamedProperty *>()) {
+ // TODO
+ continue;
+ }
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
const Attribute &attr = namedAttr.attr;
@@ -2207,12 +2579,19 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
<< " ::llvm::SmallVector<int32_t> rangeSegments;\n"
<< " for (::mlir::ValueRange range : " << argName << ")\n"
<< " rangeSegments.push_back(range.size());\n"
- << " " << builderOpState << ".addAttribute("
- << op.getGetterName(
- operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
- << "AttrName(" << builderOpState << ".name), " << odsBuilder
- << ".getDenseI32ArrayAttr(rangeSegments));"
- << " }\n";
+ << " auto rangeAttr = " << odsBuilder
+ << ".getDenseI32ArrayAttr(rangeSegments);\n";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " " << builderOpState << ".getOrAddProperties<Properties>()."
+ << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
+ << " = rangeAttr;";
+ } else {
+ body << " " << builderOpState << ".addAttribute("
+ << op.getGetterName(
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
+ << "AttrName(" << builderOpState << ".name), rangeAttr);";
+ }
+ body << " }\n";
continue;
}
@@ -2224,9 +2603,15 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// If the operation has the operand segment size attribute, add it here.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
std::string sizes = op.getGetterName(operandSegmentAttrName);
- body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
- << builderOpState << ".name), "
- << "odsBuilder.getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " (" << builderOpState << ".getOrAddProperties<Properties>()."
+ << operandSegmentAttrName << "= "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ } else {
+ body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
+ << builderOpState << ".name), "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ }
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
const NamedTypeConstraint &operand = op.getOperand(i);
if (!operand.isVariableLength()) {
@@ -2272,13 +2657,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, op.getGetterName(namedAttr.name),
- constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
+ builderOpState, namedAttr.name,
+ constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ } else {
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+ builderOpState, op.getGetterName(namedAttr.name),
+ constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ }
} else {
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, op.getGetterName(namedAttr.name),
- namedAttr.name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
+ builderOpState, namedAttr.name);
+ } else {
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+ builderOpState, op.getGetterName(namedAttr.name),
+ namedAttr.name);
+ }
}
if (emitNotNullCheck)
body.unindent() << " }\n";
@@ -2448,6 +2844,8 @@ void OpEmitter::genSideEffectInterfaceMethods() {
++operandIt;
continue;
}
+ if (arg.is<NamedProperty *>())
+ continue;
const NamedAttribute *attr = arg.get<NamedAttribute *>();
if (attr->attr.getBaseAttr().isSymbolRefAttr())
resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
@@ -2544,7 +2942,6 @@ void OpEmitter::genTypeInterfaceMethods() {
continue;
const InferredResultType &infer = op.getInferredResultType(i);
std::string typeStr;
- body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
if (infer.isArg()) {
// If this is an operand, just index into operand list to access the
// type.
@@ -2558,9 +2955,22 @@ void OpEmitter::genTypeInterfaceMethods() {
} else {
auto *attr =
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
- typeStr = ("attributes.get(\"" + attr->name +
- "\").cast<::mlir::TypedAttr>().getType()")
- .str();
+ body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
+ << " = ";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << "(properties ? properties.as<Properties *>()->"
+ << attr->name
+ << " : attributes.get(\"" + attr->name +
+ "\").dyn_cast_or_null<::mlir::TypedAttr>());\n";
+ } else {
+ body << "attributes.get(\"" + attr->name +
+ "\").dyn_cast_or_null<::mlir::TypedAttr>();\n";
+ }
+ body << " if (!odsInferredTypeAttr" << inferredTypeIdx
+ << ") return ::mlir::failure();\n";
+ typeStr =
+ ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
+ .str();
}
} else if (std::optional<StringRef> builder =
op.getResult(infer.getResultIndex())
@@ -2572,7 +2982,8 @@ void OpEmitter::genTypeInterfaceMethods() {
} else {
continue;
}
- body << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
+ << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
constructedIndices[i] = inferredTypeIdx - 1;
}
}
@@ -2615,9 +3026,11 @@ void OpEmitter::genVerifier() {
opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
auto &implBody = implMethod->body();
+ bool useProperties = emitHelper.hasProperties();
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
+ genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter,
+ useProperties);
genOperandResultVerifier(implBody, op.getOperands(), "operand");
genOperandResultVerifier(implBody, op.getResults(), "result");
@@ -3003,11 +3416,110 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/false) {
+ genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
+ bool useProperties = emitHelper.hasProperties();
+ if (useProperties) {
+ // Define the properties struct with multiple members.
+ using ConstArgument =
+ llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
+ SmallVector<ConstArgument> attrOrProperties;
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
+ attrOrProperties.push_back(&it.second);
+ }
+ for (const NamedProperty &prop : op.getProperties())
+ attrOrProperties.push_back(&prop);
+ assert(!attrOrProperties.empty());
+ std::string declarations = " struct Properties {\n";
+ llvm::raw_string_ostream os(declarations);
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ if (name.empty())
+ report_fatal_error("missing name for property");
+ std::string camelName =
+ convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
+ auto &prop = namedProperty->prop;
+ // Generate the data member using the storage type.
+ os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
+ << " " << name << "Ty " << name;
+ if (prop.hasDefaultValue())
+ os << " = " << prop.getDefaultValue();
+
+ // Emit accessors using the interface type.
+ const char *accessorFmt = R"decl(;
+ {0} get{1}() {
+ auto &propStorage = this->{2};
+ return {3};
+ }
+ void set{1}(const {0} &propValue) {
+ auto &propStorage = this->{2};
+ {4};
+ }
+)decl";
+ FmtContext fctx;
+ os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name,
+ tgfmt(prop.getConvertFromStorageCall(),
+ &fctx.addSubst("_storage", propertyStorage)),
+ tgfmt(prop.getAssignToStorageCall(),
+ &fctx.addSubst("_value", propertyValue)
+ .addSubst("_storage", propertyStorage)));
+ continue;
+ }
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const Attribute *attr = nullptr;
+ if (namedAttr->constraint)
+ attr = &*namedAttr->constraint;
+ StringRef name = namedAttr->attrName;
+ if (name.empty())
+ report_fatal_error("missing name for property attr");
+ std::string camelName =
+ convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
+ // Generate the data member using the storage type.
+ StringRef storageType;
+ if (attr) {
+ storageType = attr->getStorageType();
+ } else {
+ if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
+ report_fatal_error("unexpected AttributeMetadata");
+ }
+ // TODO: update to use native integers.
+ storageType = "::mlir::DenseI32ArrayAttr";
+ }
+ os << " using " << name << "Ty = " << storageType << ";\n"
+ << " " << name << "Ty " << name << ";\n";
+
+ // Emit accessors using the interface type.
+ if (attr) {
+ const char *accessorFmt = R"decl(
+ auto get{0}() {
+ auto &propStorage = this->{1};
+ return propStorage.{2}<{3}>();
+ }
+ void set{0}(const {3} &propValue) {
+ this->{1} = propValue;
+ }
+)decl";
+ os << formatv(accessorFmt, camelName, name,
+ attr->isOptional() || attr->hasDefaultValue()
+ ? "dyn_cast_or_null"
+ : "cast",
+ storageType);
+ }
+ }
+ os << " };\n";
+ os.flush();
+ genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
+ }
genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
- genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
"odsOpName");
+ if (useProperties)
+ genericAdaptorBase.declare<Field>("Properties", "properties");
+ genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
genericAdaptor.addTemplateParam("RangeT");
genericAdaptor.addField("RangeT", "odsOperands");
@@ -3024,9 +3536,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
+ if (useProperties)
+ paramList.emplace_back("const Properties &", "properties", "{}");
+ else
+ paramList.emplace_back("::mlir::EmptyProperties", "properties", "{}");
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
baseConstructor->addMemberInitializer("odsAttrs", "attrs");
+ if (useProperties)
+ baseConstructor->addMemberInitializer("properties", "properties");
baseConstructor->addMemberInitializer("odsRegions", "regions");
MethodBody &body = baseConstructor->body();
@@ -3037,7 +3555,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
auto *constructor = genericAdaptor.addConstructor(std::move(paramList));
- constructor->addMemberInitializer("Base", "attrs, regions");
+ constructor->addMemberInitializer("Base", "attrs, properties, regions");
constructor->addMemberInitializer("odsOperands", "values");
}
@@ -3055,8 +3573,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
- // Any invalid overlap for `getOperands` will have been diagnosed before here
- // already.
+ // Any invalid overlap for `getOperands` will have been diagnosed before
+ // here already.
if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
m->body() << " return odsOperands;";
@@ -3070,8 +3588,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr");
ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
auto &body = method->body().indent();
- body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
- << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
+ if (!useProperties)
+ body << "assert(odsAttrs && \"no attributes when constructing "
+ "adapter\");\n";
+ body << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
attr.hasDefaultValue() || attr.isOptional()
? "dyn_cast_or_null"
: "cast",
@@ -3088,6 +3608,12 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
body << "return attr;\n";
};
+ if (useProperties) {
+ auto *m = genericAdaptorBase.addInlineMethod("const Properties &",
+ "getProperties");
+ ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
+ m->body() << " return properties;";
+ }
{
auto *m =
genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes");
@@ -3124,8 +3650,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
m->body() << formatv(" return *odsRegions[{0}];", i);
}
if (numRegions > 0) {
- // Any invalid overlap for `getRegions` will have been diagnosed before here
- // already.
+ // Any invalid overlap for `getRegions` will have been diagnosed before
+ // here already.
if (auto *m =
genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions"))
m->body() << " return odsRegions;";
@@ -3142,8 +3668,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
auto *constructor =
adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
constructor->addMemberInitializer(
- adaptor.getClassName(),
- "op->getOperands(), op->getAttrDictionary(), op->getRegions()");
+ adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), "
+ "op.getProperties(), op->getRegions()");
}
// Add verification function.
@@ -3159,10 +3685,12 @@ void OpOperandAdaptorEmitter::addVerification() {
MethodParameter("::mlir::Location", "loc"));
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
+ bool useProperties = emitHelper.hasProperties();
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
+ genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
+ useProperties);
body << " return ::mlir::success();";
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ed84bcc049a6..e0472926078d 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -132,6 +132,14 @@ private:
bool withKeyword;
};
+/// This class represents the `prop-dict` directive. This directive represents
+/// the properties of the operation, expressed as a directionary.
+class PropDictDirective
+ : public DirectiveElementBase<DirectiveElement::PropDict> {
+public:
+ explicit PropDictDirective() = default;
+};
+
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
@@ -294,8 +302,9 @@ struct OperationFormat {
};
OperationFormat(const Operator &op)
-
- {
+ : useProperties(op.getDialect().usePropertiesForAttributes() &&
+ !op.getAttributes().empty()),
+ opCppClassName(op.getCppClassName()) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
@@ -351,6 +360,12 @@ struct OperationFormat {
/// A flag indicating if this operation has the SingleBlock trait.
bool hasSingleBlockTrait;
+ /// Indicate whether attribute are stored in properties.
+ bool useProperties;
+
+ /// The Operation class name
+ StringRef opCppClassName;
+
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
@@ -389,8 +404,7 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
- if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
- result.attributes)) {{
+ if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
return ::mlir::failure();
}
)";
@@ -400,30 +414,29 @@ const char *const attrParserCode = R"(
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
- if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
+ if (parser.parseAttribute({0}Attr, {1}))
return ::mlir::failure();
)";
const char *const optionalAttrParserCode = R"(
- {
- ::mlir::OptionalParseResult parseResult =
- parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
- if (parseResult.has_value() && failed(*parseResult))
- return ::mlir::failure();
- }
+ ::mlir::OptionalParseResult parseResult{0}Attr =
+ parser.parseOptionalAttribute({0}Attr, {1});
+ if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
+ return ::mlir::failure();
+ if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
- if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
+ if (parser.parseSymbolName({0}Attr))
return ::mlir::failure();
)";
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
- (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
+ (void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute.
@@ -434,6 +447,7 @@ const char *const optionalSymbolNameAttrParserCode = R"(
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
+/// {6}: The attribute assignment expression
const char *const enumAttrParserCode = R"(
{
::llvm::StringRef attrStr;
@@ -460,7 +474,7 @@ const char *const enumAttrParserCode = R"(
<< "{0} attribute specification: \"" << attrStr << '"';;
{0}Attr = {3};
- result.addAttribute("{0}", {0}Attr);
+ {6}
}
}
)";
@@ -572,6 +586,7 @@ const char *const inferReturnTypesParserCode = R"(
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
+ result.getRawProperties(),
result.regions, inferredReturnTypes)))
return ::mlir::failure();
result.addTypes(inferredReturnTypes);
@@ -930,7 +945,9 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
}
/// Generate the parser for a custom directive.
-static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
+static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
+ bool useProperties,
+ StringRef opCppClassName) {
body << " {\n";
// Preprocess the directive variables.
@@ -1003,9 +1020,15 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
+ var->name, opCppClassName);
+ } else {
+ body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
+ var->name);
+ }
- body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
- var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) {
@@ -1041,7 +1064,8 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
/// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
- FmtContext &attrTypeCtx, bool parseAsOptional) {
+ FmtContext &attrTypeCtx, bool parseAsOptional,
+ bool useProperties, StringRef opCppClassName) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1076,46 +1100,68 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
});
errorMessageOS << "]\");";
}
+ std::string attrAssignment;
+ if (useProperties) {
+ attrAssignment =
+ formatv(" "
+ "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
+ var->name, opCppClassName);
+ } else {
+ attrAssignment =
+ formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
+ }
body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr,
- validCaseKeywordsStr, errorMessage);
+ validCaseKeywordsStr, errorMessage, attrAssignment);
}
// Generate the parser for an attribute.
static void genAttrParser(AttributeVariable *attr, MethodBody &body,
- FmtContext &attrTypeCtx, bool parseAsOptional) {
+ FmtContext &attrTypeCtx, bool parseAsOptional,
+ bool useProperties, StringRef opCppClassName) {
const NamedAttribute *var = attr->getVar();
// Check to see if we can parse this as an enum attribute.
if (canFormatEnumAttr(var))
- return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional);
+ return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
+ useProperties, opCppClassName);
// Check to see if we should parse this as a symbol name attribute.
if (shouldFormatSymbolNameAttr(var)) {
body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
: symbolNameAttrParserCode,
var->name);
- return;
- }
-
- // If this attribute has a buildable type, use that when parsing the
- // attribute.
- std::string attrTypeStr;
- if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
- llvm::raw_string_ostream os(attrTypeStr);
- os << tgfmt(*typeBuilder, &attrTypeCtx);
} else {
- attrTypeStr = "::mlir::Type{}";
+
+ // If this attribute has a buildable type, use that when parsing the
+ // attribute.
+ std::string attrTypeStr;
+ if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+ llvm::raw_string_ostream os(attrTypeStr);
+ os << tgfmt(*typeBuilder, &attrTypeCtx);
+ } else {
+ attrTypeStr = "::mlir::Type{}";
+ }
+ if (parseAsOptional) {
+ body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ } else {
+ if (attr->shouldBeQualified() ||
+ var->attr.getStorageType() == "::mlir::Attribute")
+ body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+ else
+ body << formatv(attrParserCode, var->name, attrTypeStr);
+ }
}
- if (parseAsOptional) {
- body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ if (useProperties) {
+ body << formatv(
+ " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
+ "{0}Attr;\n",
+ var->name, opCppClassName);
} else {
- if (attr->shouldBeQualified() ||
- var->attr.getStorageType() == "::mlir::Attribute")
- body << formatv(genericAttrParserCode, var->name, attrTypeStr);
- else
- body << formatv(attrParserCode, var->name, attrTypeStr);
+ body << formatv(
+ " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
+ var->name);
}
}
@@ -1170,8 +1216,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
if (!thenGroup == optional->isInverted()) {
// Add the anchor unit attribute to the operation state.
- body << " result.addAttribute(\"" << anchorAttr->getVar()->name
- << "\", parser.getBuilder().getUnitAttr());\n";
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = "
+ "parser.getBuilder().getUnitAttr();",
+ anchorAttr->getVar()->name, opCppClassName);
+ } else {
+ body << " result.addAttribute(\"" << anchorAttr->getVar()->name
+ << "\", parser.getBuilder().getUnitAttr());\n";
+ }
}
}
@@ -1190,7 +1243,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
// parsing of the rest of the elements.
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
- genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true);
+ genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
+ useProperties, opCppClassName);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (::mlir::succeeded(parser.parseOptional";
@@ -1248,8 +1302,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << formatv(oilistParserCode, lelementName);
if (AttributeVariable *unitAttrElem =
oilist->getUnitAttrParsingElement(pelement)) {
- body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
- << "\", UnitAttr::get(parser.getContext()));\n";
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = "
+ "parser.getBuilder().getUnitAttr();",
+ unitAttrElem->getVar()->name, opCppClassName);
+ } else {
+ body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
+ << "\", UnitAttr::get(parser.getContext()));\n";
+ }
} else {
for (FormatElement *el : pelement)
genElementParser(el, body, attrTypeCtx);
@@ -1275,7 +1336,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
bool parseAsOptional =
(genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
- genAttrParser(attr, body, attrTypeCtx, parseAsOptional);
+ genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
+ opCppClassName);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1311,13 +1373,27 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
/// Directives.
} else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
- body << " if (parser.parseOptionalAttrDict"
- << (attrDict->isWithKeyword() ? "WithKeyword" : "")
- << "(result.attributes))\n"
+ body.indent() << "{\n";
+ body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
+ << "if (parser.parseOptionalAttrDict"
+ << (attrDict->isWithKeyword() ? "WithKeyword" : "")
+ << "(result.attributes))\n"
+ << " return ::mlir::failure();\n";
+ if (useProperties) {
+ body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
+ "[&]() {\n"
+ << " return parser.emitError(loc) << \"'\" << "
+ "result.name.getStringRef() << \"' op \";\n"
+ << " })))\n"
+ << " return ::mlir::failure();\n";
+ }
+ body.unindent() << "}\n";
+ body.unindent();
+ } else if (auto *attrDict = dyn_cast<PropDictDirective>(element)) {
+ body << " if (parseProperties(parser, result))\n"
<< " return ::mlir::failure();\n";
} else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
- genCustomDirectiveParser(customDir, body);
-
+ genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " if (parser.parseOperandList(allOperands))\n"
@@ -1571,8 +1647,16 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- body << " result.addAttribute(\"operand_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" "
+ "result.getOrAddProperties<{0}::Properties>().operand_"
+ "segment_sizes = "
+ "(parser.getBuilder().getDenseI32ArrayAttr({{",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"operand_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ }
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
@@ -1586,18 +1670,36 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (!operand.isVariadicOfVariadic())
continue;
- body << llvm::formatv(
- " result.addAttribute(\"{0}\", "
- "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));\n",
- operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
- operand.name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << llvm::formatv(
+ " result.getOrAddProperties<{0}::Properties>().{1} = "
+ "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
+ op.getCppClassName(),
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
+ operand.name);
+ } else {
+ body << llvm::formatv(
+ " result.addAttribute(\"{0}\", "
+ "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
+ "\n",
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
+ operand.name);
+ }
}
}
if (!allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- body << " result.addAttribute(\"result_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(
+ " "
+ "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
+ "(parser.getBuilder().getDenseI32ArrayAttr({{",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"result_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ }
auto interleaveFn = [&](const NamedTypeConstraint &result) {
// If the result is variadic emit the parsed size.
if (result.isVariableLength())
@@ -1641,6 +1743,14 @@ const char *enumAttrBeginPrinterCode = R"(
auto caseValueStr = {1}(caseValue);
)";
+/// Generate the printer for the 'prop-dict' directive.
+static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
+ MethodBody &body) {
+ body << " _odsPrinter << \" \";\n"
+ << " printProperties(this->getContext(), _odsPrinter, "
+ "getProperties());\n";
+}
+
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
MethodBody &body, bool withKeyword) {
@@ -1898,7 +2008,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
})
.Case<AttributeVariable>([&](AttributeVariable *element) {
Attribute attr = element->getVar()->attr;
- body << "(*this)->getAttr(\"" << element->getVar()->name << "\")";
+ body << op.getGetterName(element->getVar()->name) << "Attr()";
if (attr.isOptional())
return; // done
if (attr.hasDefaultValue()) {
@@ -1906,7 +2016,8 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
// default value.
FmtContext fctx;
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
- body << " != "
+ body << " && " << op.getGetterName(element->getVar()->name)
+ << "Attr() != "
<< tgfmt(attr.getConstBuilderTemplate(), &fctx,
attr.getDefaultValue());
return;
@@ -2063,6 +2174,13 @@ void OperationFormat::genElementPrinter(FormatElement *element,
return;
}
+ // Emit the attribute dictionary.
+ if (auto *propDict = dyn_cast<PropDictDirective>(element)) {
+ genPropDictPrinter(*this, op, body);
+ lastWasPunctuation = false;
+ return;
+ }
+
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
@@ -2300,6 +2418,7 @@ private:
ConstArgument findSeenArg(StringRef name);
/// Parse the various different directives.
+ FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
bool withKeyword);
FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
@@ -2329,6 +2448,7 @@ private:
// The following are various bits of format state used for verification
// during parsing.
bool hasAttrDict = false;
+ bool hasPropDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
bool canInferResultTypes = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
@@ -2873,6 +2993,8 @@ FailureOr<FormatElement *>
OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
Context ctx) {
switch (kind) {
+ case FormatToken::kw_prop_dict:
+ return parsePropDictDirective(loc, ctx);
case FormatToken::kw_attr_dict:
return parseAttrDictDirective(loc, ctx,
/*withKeyword=*/false);
@@ -2925,6 +3047,23 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
return create<AttrDictDirective>(withKeyword);
}
+FailureOr<FormatElement *>
+OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
+ if (context == TypeDirectiveContext)
+ return emitError(loc, "'prop-dict' directive can only be used as a "
+ "top-level directive");
+
+ if (context == RefDirectiveContext)
+ llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
+ // Otherwise, this is a top-level context.
+
+ if (hasPropDict)
+ return emitError(loc, "'prop-dict' directive has already been seen");
+ hasPropDict = true;
+
+ return create<PropDictDirective>();
+}
+
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *argument : arguments) {