summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2023-04-24 11:53:58 -0700
committerJacques Pienaar <jpienaar@google.com>2023-04-24 11:53:58 -0700
commit09115580056fc57d5cdaa1de20631a838a4ea1c4 (patch)
tree822b73e5192eedae1c3a2a9b1ae4f961dcad7a22
parent8e091b1220e09c0c1000ed676c1c92a09e871129 (diff)
downloadllvm-09115580056fc57d5cdaa1de20631a838a4ea1c4.tar.gz
[mlir] Dialect type/attr bytecode read/write generator.
Tool to help generate dialect bytecode Attribute & Type reader/writing. Show usage by flipping builtin dialect. It helps reduce boilerplate when writing dialect bytecode attribute and type readers/writers. It is not an attempt at a generic spec mechanism but rather practically focussing on boilerplate reduction while also considering that it need not be the only in memory format and make it relatively easy to change. There should be some cleanup in follow up as we expand to more dialects. Differential Revision: https://reviews.llvm.org/D144820
-rw-r--r--mlir/include/mlir/Bytecode/BytecodeImplementation.h31
-rw-r--r--mlir/include/mlir/IR/BuiltinDialectBytecode.td566
-rw-r--r--mlir/include/mlir/IR/BytecodeBase.td159
-rw-r--r--mlir/include/mlir/IR/CMakeLists.txt4
-rw-r--r--mlir/lib/IR/BuiltinDialectBytecode.cpp1181
-rw-r--r--mlir/lib/IR/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp467
-rw-r--r--mlir/tools/mlir-tblgen/CMakeLists.txt1
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel22
9 files changed, 1312 insertions, 1120 deletions
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index ea9bcad735b3..6e7b9ff26342 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -345,6 +345,37 @@ public:
}
};
+/// Helper for resource handle reading that returns LogicalResult.
+template <typename T, typename... Ts>
+static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
+ FailureOr<T> &value, Ts &&...params) {
+ FailureOr<T> handle = reader.readResourceHandle<T>();
+ if (failed(handle))
+ return failure();
+ if (auto *result = dyn_cast<T>(&*handle)) {
+ value = std::move(*result);
+ return success();
+ }
+ return failure();
+}
+
+/// Helper method that injects context only if needed, this helps unify some of
+/// the attribute construction methods.
+template <typename T, typename... Ts>
+auto get(MLIRContext *context, Ts &&...params) {
+ // Prefer a direct `get` method if one exists.
+ if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
+ (void)context;
+ return T::get(std::forward<Ts>(params)...);
+ } else if constexpr (llvm::is_detected<detail::has_get_method, T,
+ MLIRContext *, Ts...>::value) {
+ return T::get(context, std::forward<Ts>(params)...);
+ } else {
+ // Otherwise, pass to the base get.
+ return T::Base::get(context, std::forward<Ts>(params)...);
+ }
+}
+
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
new file mode 100644
index 000000000000..b59f96c9fa9f
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -0,0 +1,566 @@
+//===-- BuiltinBytecode.td - Builtin bytecode defs ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the Builtin bytecode reader/writer definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_BYTECODE
+#define BUILTIN_BYTECODE
+
+include "mlir/IR/BytecodeBase.td"
+
+def LocationAttr : AttributeKind;
+def ShapedType: WithType<"ShapedType", Type>;
+
+def Location : CompositeBytecode {
+ dag members = (attr
+ WithGetter<"(LocationAttr)$_attrType", WithType<"LocationAttr", LocationAttr>>:$value
+ );
+ let cBuilder = "Location($_args)";
+}
+
+def String :
+ WithParser <"succeeded($_reader.readString($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeOwnedString($_getter)",
+ WithGetter <"$_attrType",
+ WithType <"StringRef">>>>>;
+
+// enum AttributeCode {
+// /// ArrayAttr {
+// /// elements: Attribute[]
+// /// }
+// ///
+// kArrayAttr = 0,
+//
+def ArrayAttr : DialectAttribute<(attr
+ Array<Attribute>:$value
+)>;
+
+let cType = "StringAttr" in {
+// /// StringAttr {
+// /// value: string
+// /// }
+// kStringAttr = 2,
+def StringAttr : DialectAttribute<(attr
+ String:$value
+)> {
+ let printerPredicate = "$_val.getType().isa<NoneType>()";
+}
+
+// /// StringAttrWithType {
+// /// value: string,
+// /// type: Type
+// /// }
+// /// A variant of StringAttr with a type.
+// kStringAttrWithType = 3,
+def StringAttrWithType : DialectAttribute<(attr
+ String:$value,
+ Type:$type
+)> { let printerPredicate = "!$_val.getType().isa<NoneType>()"; }
+}
+
+// /// DictionaryAttr {
+// /// attrs: <StringAttr, Attribute>[]
+// /// }
+// kDictionaryAttr = 1,
+def NamedAttribute : CompositeBytecode {
+ dag members = (attr
+ StringAttr:$name,
+ Attribute:$value
+ );
+ let cBuilder = "NamedAttribute($_args)";
+}
+def DictionaryAttr : DialectAttribute<(attr
+ Array<NamedAttribute>:$value
+)>;
+
+// /// FlatSymbolRefAttr {
+// /// rootReference: StringAttr
+// /// }
+// /// A variant of SymbolRefAttr with no leaf references.
+// kFlatSymbolRefAttr = 4,
+def FlatSymbolRefAttr: DialectAttribute<(attr
+ StringAttr:$rootReference
+)>;
+
+// /// SymbolRefAttr {
+// /// rootReference: StringAttr,
+// /// leafReferences: FlatSymbolRefAttr[]
+// /// }
+// kSymbolRefAttr = 5,
+def SymbolRefAttr: DialectAttribute<(attr
+ StringAttr:$rootReference,
+ Array<FlatSymbolRefAttr>:$nestedReferences
+)>;
+
+// /// TypeAttr {
+// /// value: Type
+// /// }
+// kTypeAttr = 6,
+def TypeAttr: DialectAttribute<(attr
+ Type:$value
+)>;
+
+// /// UnitAttr {
+// /// }
+// kUnitAttr = 7,
+def UnitAttr: DialectAttribute<(attr)>;
+
+// /// IntegerAttr {
+// /// type: Type
+// /// value: APInt,
+// /// }
+// kIntegerAttr = 8,
+def IntegerAttr: DialectAttribute<(attr
+ Type:$type,
+ KnownWidthAPInt<"type">:$value
+)> {
+ let cBuilder = "get<$_resultType>(context, type, *value)";
+}
+
+//
+// /// FloatAttr {
+// /// type: FloatType
+// /// value: APFloat
+// /// }
+// kFloatAttr = 9,
+defvar FloatType = Type;
+def FloatAttr : DialectAttribute<(attr
+ FloatType:$type,
+ KnownSemanticsAPFloat<"type">:$value
+)> {
+ let cBuilder = "get<$_resultType>(context, type, *value)";
+}
+
+// /// CallSiteLoc {
+// /// callee: LocationAttr,
+// /// caller: LocationAttr
+// /// }
+// kCallSiteLoc = 10,
+def CallSiteLoc : DialectAttribute<(attr
+ LocationAttr:$callee,
+ LocationAttr:$caller
+)>;
+
+// /// FileLineColLoc {
+// /// filename: StringAttr,
+// /// line: varint,
+// /// column: varint
+// /// }
+// kFileLineColLoc = 11,
+def FileLineColLoc : DialectAttribute<(attr
+ StringAttr:$filename,
+ VarInt:$line,
+ VarInt:$column
+)>;
+
+let cType = "FusedLoc",
+ cBuilder = "cast<FusedLoc>(get<FusedLoc>(context, $_args))" in {
+// /// FusedLoc {
+// /// locations: Location[]
+// /// }
+// kFusedLoc = 12,
+def FusedLoc : DialectAttribute<(attr
+ Array<Location>:$locations
+)> {
+ let printerPredicate = "!$_val.getMetadata()";
+}
+
+// /// FusedLocWithMetadata {
+// /// locations: LocationAttr[],
+// /// metadata: Attribute
+// /// }
+// /// A variant of FusedLoc with metadata.
+// kFusedLocWithMetadata = 13,
+def FusedLocWithMetadata : DialectAttribute<(attr
+ Array<Location>:$locations,
+ Attribute:$metadata
+)> {
+ let printerPredicate = "$_val.getMetadata()";
+}
+}
+
+// /// NameLoc {
+// /// name: StringAttr,
+// /// childLoc: LocationAttr
+// /// }
+// kNameLoc = 14,
+def NameLoc : DialectAttribute<(attr
+ StringAttr:$name,
+ LocationAttr:$childLoc
+)>;
+
+// /// UnknownLoc {
+// /// }
+// kUnknownLoc = 15,
+def UnknownLoc : DialectAttribute<(attr)>;
+
+// /// DenseResourceElementsAttr {
+// /// type: ShapedType,
+// /// handle: ResourceHandle
+// /// }
+// kDenseResourceElementsAttr = 16,
+def DenseResourceElementsAttr : DialectAttribute<(attr
+ ShapedType:$type,
+ ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
+)> {
+ // Note: order of serialization does not match order of builder.
+ let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
+}
+
+let cType = "RankedTensorType" in {
+// /// RankedTensorType {
+// /// shape: svarint[],
+// /// elementType: Type,
+// /// }
+// ///
+// kRankedTensorType = 13,
+def RankedTensorType : DialectType<(type
+ Array<SignedVarInt>:$shape,
+ Type:$elementType
+)> {
+ let printerPredicate = "!$_val.getEncoding()";
+}
+
+// /// RankedTensorTypeWithEncoding {
+// /// encoding: Attribute,
+// /// shape: svarint[],
+// /// elementType: Type
+// /// }
+// /// Variant of RankedTensorType with an encoding.
+// kRankedTensorTypeWithEncoding = 14,
+def RankedTensorTypeWithEncoding : DialectType<(type
+ Attribute:$encoding,
+ Array<SignedVarInt>:$shape,
+ Type:$elementType
+)> {
+ let printerPredicate = "$_val.getEncoding()";
+ // Note: order of serialization does not match order of builder.
+ let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
+}
+}
+
+// /// DenseArrayAttr {
+// /// elementType: Type,
+// /// size: varint
+// /// data: blob
+// /// }
+// kDenseArrayAttr = 17,
+def DenseArrayAttr : DialectAttribute<(attr
+ Type:$elementType,
+ VarInt:$size,
+ Blob:$rawData
+)>;
+
+// /// DenseIntOrFPElementsAttr {
+// /// type: ShapedType,
+// /// data: blob
+// /// }
+// kDenseIntOrFPElementsAttr = 18,
+def DenseElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
+def DenseIntOrFPElementsAttr : DialectAttribute<(attr
+ ShapedType:$type,
+ Blob:$rawData
+)> {
+ let cBuilder = "cast<$_resultType>($_resultType::getFromRawBuffer($_args))";
+}
+
+// /// DenseStringElementsAttr {
+// /// type: ShapedType,
+// /// isSplat: varint,
+// /// data: string[]
+// /// }
+// kDenseStringElementsAttr = 19,
+def DenseStringElementsAttr : DialectAttribute<(attr
+ ShapedType:$type,
+ WithGetter<"$_attrType.isSplat()", VarInt>:$_isSplat,
+ WithBuilder<"$_args",
+ WithType<"SmallVector<StringRef>",
+ WithParser <"succeeded(readPotentiallySplatString($_reader, type, _isSplat, $_var))",
+ WithPrinter<"writePotentiallySplatString($_writer, $_name)">>>>:$rawStringData
+)>;
+
+// /// SparseElementsAttr {
+// /// type: ShapedType,
+// /// indices: DenseIntElementsAttr,
+// /// values: DenseElementsAttr
+// /// }
+// kSparseElementsAttr = 20,
+def DenseIntElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
+def SparseElementsAttr : DialectAttribute<(attr
+ ShapedType:$type,
+ DenseIntElementsAttr:$indices,
+ DenseElementsAttr:$values
+)>;
+
+// Types
+// -----
+
+// enum TypeCode {
+// /// IntegerType {
+// /// widthAndSignedness: varint // (width << 2) | (signedness)
+// /// }
+// ///
+// kIntegerType = 0,
+def IntegerType : DialectType<(type
+ // Yes not pretty,
+ WithParser<"succeeded($_reader.readVarInt($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeVarInt($_name.getWidth() << 2 | $_name.getSignedness())",
+ WithType <"uint64_t">>>>:$_widthAndSignedness,
+ // Split up parsed varint for create method.
+ LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width,
+ LocalVar<"IntegerType::SignednessSemantics",
+ "static_cast<IntegerType::SignednessSemantics>(_widthAndSignedness & 0x3)">:$signedness
+)>;
+
+//
+// /// IndexType {
+// /// }
+// ///
+// kIndexType = 1,
+def IndexType : DialectType<(type)>;
+
+// /// FunctionType {
+// /// inputs: Type[],
+// /// results: Type[]
+// /// }
+// ///
+// kFunctionType = 2,
+def FunctionType : DialectType<(type
+ Array<Type>:$inputs,
+ Array<Type>:$results
+)>;
+
+// /// BFloat16Type {
+// /// }
+// ///
+// kBFloat16Type = 3,
+def BFloat16Type : DialectType<(type)>;
+
+// /// Float16Type {
+// /// }
+// ///
+// kFloat16Type = 4,
+def Float16Type : DialectType<(type)>;
+
+// /// Float32Type {
+// /// }
+// ///
+// kFloat32Type = 5,
+def Float32Type : DialectType<(type)>;
+
+// /// Float64Type {
+// /// }
+// ///
+// kFloat64Type = 6,
+def Float64Type : DialectType<(type)>;
+
+// /// Float80Type {
+// /// }
+// ///
+// kFloat80Type = 7,
+def Float80Type : DialectType<(type)>;
+
+// /// Float128Type {
+// /// }
+// ///
+// kFloat128Type = 8,
+def Float128Type : DialectType<(type)>;
+
+// /// ComplexType {
+// /// elementType: Type
+// /// }
+// ///
+// kComplexType = 9,
+def ComplexType : DialectType<(type
+ Type:$elementType
+)>;
+
+def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>;
+
+let cType = "MemRefType" in {
+// /// MemRefType {
+// /// shape: svarint[],
+// /// elementType: Type,
+// /// layout: Attribute
+// /// }
+// ///
+// kMemRefType = 10,
+def MemRefType : DialectType<(type
+ Array<SignedVarInt>:$shape,
+ Type:$elementType,
+ MemRefLayout:$layout
+)> {
+ let printerPredicate = "!$_val.getMemorySpace()";
+}
+
+// /// MemRefTypeWithMemSpace {
+// /// memorySpace: Attribute,
+// /// shape: svarint[],
+// /// elementType: Type,
+// /// layout: Attribute
+// /// }
+// /// Variant of MemRefType with non-default memory space.
+// kMemRefTypeWithMemSpace = 11,
+def MemRefTypeWithMemSpace : DialectType<(type
+ Attribute:$memorySpace,
+ Array<SignedVarInt>:$shape,
+ Type:$elementType,
+ MemRefLayout:$layout
+)> {
+ let printerPredicate = "!!$_val.getMemorySpace()";
+ // Note: order of serialization does not match order of builder.
+ let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
+}
+}
+
+// /// NoneType {
+// /// }
+// ///
+// kNoneType = 12,
+def NoneType : DialectType<(type)>;
+
+// /// TupleType {
+// /// elementTypes: Type[]
+// /// }
+// kTupleType = 15,
+def TupleType : DialectType<(type
+ Array<Type>:$types
+)>;
+
+let cType = "UnrankedMemRefType" in {
+// /// UnrankedMemRefType {
+// /// elementType: Type
+// /// }
+// ///
+// kUnrankedMemRefType = 16,
+def UnrankedMemRefType : DialectType<(type
+ Type:$elementType
+)> {
+ let printerPredicate = "!$_val.getMemorySpace()";
+ let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
+}
+
+// /// UnrankedMemRefTypeWithMemSpace {
+// /// memorySpace: Attribute,
+// /// elementType: Type
+// /// }
+// /// Variant of UnrankedMemRefType with non-default memory space.
+// kUnrankedMemRefTypeWithMemSpace = 17,
+def UnrankedMemRefTypeWithMemSpace : DialectType<(type
+ Attribute:$memorySpace,
+ Type:$elementType
+)> {
+ let printerPredicate = "$_val.getMemorySpace()";
+ // Note: order of serialization does not match order of builder.
+ let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
+}
+}
+
+// /// UnrankedTensorType {
+// /// elementType: Type
+// /// }
+// ///
+// kUnrankedTensorType = 18,
+def UnrankedTensorType : DialectType<(type
+ Type:$elementType
+)>;
+
+let cType = "VectorType" in {
+// /// VectorType {
+// /// shape: svarint[],
+// /// elementType: Type
+// /// }
+// ///
+// kVectorType = 19,
+def VectorType : DialectType<(type
+ Array<SignedVarInt>:$shape,
+ Type:$elementType
+)> {
+ let printerPredicate = "!$_val.getNumScalableDims()";
+}
+
+// /// VectorTypeWithScalableDims {
+// /// numScalableDims: varint,
+// /// shape: svarint[],
+// /// elementType: Type
+// /// }
+// /// Variant of VectorType with scalable dimensions.
+// kVectorTypeWithScalableDims = 20,
+def VectorTypeWithScalableDims : DialectType<(type
+ VarInt:$numScalableDims,
+ Array<SignedVarInt>:$shape,
+ Type:$elementType
+)> {
+ let printerPredicate = "$_val.getNumScalableDims()";
+ // Note: order of serialization does not match order of builder.
+ let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
+}
+}
+
+/// This enum contains marker codes used to indicate which attribute is
+/// currently being decoded, and how it should be decoded. The order of these
+/// codes should generally be unchanged, as any changes will inevitably break
+/// compatibility with older bytecode.
+
+def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
+ let elems = [
+ ArrayAttr,
+ DictionaryAttr,
+ StringAttr,
+ StringAttrWithType,
+ FlatSymbolRefAttr,
+ SymbolRefAttr,
+ TypeAttr,
+ UnitAttr,
+ IntegerAttr,
+ FloatAttr,
+ CallSiteLoc,
+ FileLineColLoc,
+ FusedLoc,
+ FusedLocWithMetadata,
+ NameLoc,
+ UnknownLoc,
+ DenseResourceElementsAttr,
+ DenseArrayAttr,
+ DenseIntOrFPElementsAttr,
+ DenseStringElementsAttr,
+ SparseElementsAttr
+ ];
+}
+
+def BuiltinDialectTypes : DialectTypes<"Builtin"> {
+ let elems = [
+ IntegerType,
+ IndexType,
+ FunctionType,
+ BFloat16Type,
+ Float16Type,
+ Float32Type,
+ Float64Type,
+ Float80Type,
+ Float128Type,
+ ComplexType,
+ MemRefType,
+ MemRefTypeWithMemSpace,
+ NoneType,
+ RankedTensorType,
+ RankedTensorTypeWithEncoding,
+ TupleType,
+ UnrankedMemRefType,
+ UnrankedMemRefTypeWithMemSpace,
+ UnrankedTensorType,
+ VectorType,
+ VectorTypeWithScalableDims
+ ];
+}
+
+#endif // BUILTIN_BYTECODE
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
new file mode 100644
index 000000000000..8cadf978b347
--- /dev/null
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -0,0 +1,159 @@
+//===-- BytecodeBase.td - Base bytecode R/W defs -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the base bytecode reader/writer definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTECODE_BASE
+#define BYTECODE_BASE
+
+class Bytecode<string parse="", string build="", string print="", string get="", string t=""> {
+ // Template for parsing.
+ // $_reader == dialect bytecode reader
+ // $_resultType == result type of parsed instance
+ // $_var == variable being parsed
+ // If parser is not specified, then the parse of members is used.
+ string cParser = parse;
+
+ // Template for building from parsed.
+ // $_resultType == result type of parsed instance
+ // $_args == args/members comma separated
+ string cBuilder = build;
+
+ // Template for printing.
+ // $_writer == dialect bytecode writer
+ // $_name == parent attribute/type name
+ // $_getter == getter
+ string cPrinter = print;
+
+ // Template for getter from in memory form.
+ // $_attrType == attribute/type
+ // $_member == member instance
+ // $_getMember == get + UpperCamelFromSnake($_member)
+ string cGetter = get;
+
+ // Type built.
+ // Note: if cType is empty, then name of def is used.
+ string cType = t;
+
+ // Predicate guarding parse method as an Attribute/Type could have multiple
+ // parse methods, specify predicates to be orthogonal and cover entire
+ // "print space" to avoid order dependence.
+ // If empty then method is unconditional.
+ // $_val == predicate function to apply on value dyn_casted to cType.
+ string printerPredicate = "";
+}
+
+class WithParser<string p="", Bytecode t=Bytecode<>> :
+ Bytecode<p, t.cBuilder, t.cPrinter, t.cGetter, t.cType>;
+class WithBuilder<string b="", Bytecode t=Bytecode<>> :
+ Bytecode<t.cParser, b, t.cPrinter, t.cGetter, t.cType>;
+class WithPrinter<string p="", Bytecode t=Bytecode<>> :
+ Bytecode<t.cParser, t.cBuilder, p, t.cGetter, t.cType>;
+class WithType<string ty="", Bytecode t=Bytecode<>> :
+ Bytecode<t.cParser, t.cBuilder, t.cPrinter, t.cGetter, ty>;
+class WithGetter<string g="", Bytecode t=Bytecode<>> :
+ Bytecode<t.cParser, t.cBuilder, t.cPrinter, g, t.cType>;
+
+class CompositeBytecode<string t = ""> : WithType<t>;
+
+class AttributeKind :
+ WithParser <"succeeded($_reader.readAttribute($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeAttribute($_getter)">>>;
+def Attribute : AttributeKind;
+class TypeKind :
+ WithParser <"succeeded($_reader.readType($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeType($_getter)">>>;
+def Type : TypeKind;
+def VarInt :
+ WithParser <"succeeded($_reader.readVarInt($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeVarInt($_getter)",
+ WithType <"uint64_t">>>>;
+def SignedVarInt :
+ WithParser <"succeeded($_reader.readSignedVarInt($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeSignedVarInt($_getter)",
+ WithGetter<"$_attrType",
+ WithType <"int64_t">>>>>;
+def Blob :
+ WithParser <"succeeded($_reader.readBlob($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeOwnedBlob($_getter)",
+ WithType <"ArrayRef<char>">>>>;
+
+class KnownWidthAPInt<string s> :
+ WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeAPIntWithKnownWidth($_getter)",
+ WithType <"FailureOr<APInt>">>>>;
+class KnownSemanticsAPFloat<string s> :
+ WithParser <"succeeded(readAPFloatWithKnownSemantics($_reader, " # s # ", $_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeAPFloatWithKnownSemantics($_getter)",
+ WithType <"FailureOr<APFloat>">>>>;
+class ResourceHandle<string s> :
+ WithParser <"succeeded(readResourceHandle<" # s # ">($_reader, $_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeResourceHandle($_getter)",
+ WithType <"FailureOr<" # s # ">">>>>;
+
+// Helper to define variable that is defined later but not parsed nor printed.
+class LocalVar<string t, string d> :
+ WithParser <"(($_var = " # d # "), true)",
+ WithBuilder<"$_args",
+ WithPrinter<"",
+ WithType <t>>>>;
+
+// Array instances.
+class Array<Bytecode t> {
+ Bytecode elemT = t;
+
+ string cBuilder = "$_args";
+}
+
+// Define dialect attribute or type.
+class DialectAttrOrType<dag d> {
+ // Any members starting with underscore is not fed to create function but
+ // treated as purely local variable.
+ dag members = d;
+
+ // When needing to specify a custom return type.
+ string cType = "";
+
+ // Any post-processing that needs to be done.
+ code postProcess = "";
+}
+
+class DialectAttribute<dag d> : DialectAttrOrType<d>, AttributeKind {
+ let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
+ let cBuilder = "get<$_resultType>(context, $_args)";
+}
+class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
+ let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
+ let cBuilder = "get<$_resultType>(context, $_args)";
+}
+
+class DialectAttributes<string d> {
+ string dialect = d;
+ list<DialectAttrOrType> elems;
+}
+
+class DialectTypes<string d> {
+ string dialect = d;
+ list<DialectAttrOrType> elems;
+}
+
+def attr;
+def type;
+
+#endif // BYTECODE_BASE
+
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 78d41d6dc4ab..404f130022f1 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -17,6 +17,10 @@ mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(MLIRBuiltinDialectIncGen)
+set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td)
+mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin")
+add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
+
set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td)
mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 22a563dd7b2a..40af5f3b1744 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -8,6 +8,7 @@
#include "BuiltinDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -17,546 +18,61 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Encoding
-//===----------------------------------------------------------------------===//
-
-namespace {
-namespace builtin_encoding {
-/// This enum contains marker codes used to indicate which attribute is
-/// currently being decoded, and how it should be decoded. The order of these
-/// codes should generally be unchanged, as any changes will inevitably break
-/// compatibility with older bytecode.
-enum AttributeCode {
- /// ArrayAttr {
- /// elements: Attribute[]
- /// }
- ///
- kArrayAttr = 0,
-
- /// DictionaryAttr {
- /// attrs: <StringAttr, Attribute>[]
- /// }
- kDictionaryAttr = 1,
-
- /// StringAttr {
- /// value: string
- /// }
- kStringAttr = 2,
-
- /// StringAttrWithType {
- /// value: string,
- /// type: Type
- /// }
- /// A variant of StringAttr with a type.
- kStringAttrWithType = 3,
-
- /// FlatSymbolRefAttr {
- /// rootReference: StringAttr
- /// }
- /// A variant of SymbolRefAttr with no leaf references.
- kFlatSymbolRefAttr = 4,
-
- /// SymbolRefAttr {
- /// rootReference: StringAttr,
- /// leafReferences: FlatSymbolRefAttr[]
- /// }
- kSymbolRefAttr = 5,
-
- /// TypeAttr {
- /// value: Type
- /// }
- kTypeAttr = 6,
-
- /// UnitAttr {
- /// }
- kUnitAttr = 7,
-
- /// IntegerAttr {
- /// type: Type
- /// value: APInt,
- /// }
- kIntegerAttr = 8,
-
- /// FloatAttr {
- /// type: FloatType
- /// value: APFloat
- /// }
- kFloatAttr = 9,
-
- /// CallSiteLoc {
- /// callee: LocationAttr,
- /// caller: LocationAttr
- /// }
- kCallSiteLoc = 10,
-
- /// FileLineColLoc {
- /// file: StringAttr,
- /// line: varint,
- /// column: varint
- /// }
- kFileLineColLoc = 11,
-
- /// FusedLoc {
- /// locations: LocationAttr[]
- /// }
- kFusedLoc = 12,
-
- /// FusedLocWithMetadata {
- /// locations: LocationAttr[],
- /// metadata: Attribute
- /// }
- /// A variant of FusedLoc with metadata.
- kFusedLocWithMetadata = 13,
-
- /// NameLoc {
- /// name: StringAttr,
- /// childLoc: LocationAttr
- /// }
- kNameLoc = 14,
-
- /// UnknownLoc {
- /// }
- kUnknownLoc = 15,
-
- /// DenseResourceElementsAttr {
- /// type: Type,
- /// handle: ResourceHandle
- /// }
- kDenseResourceElementsAttr = 16,
-
- /// DenseArrayAttr {
- /// type: RankedTensorType,
- /// data: blob
- /// }
- kDenseArrayAttr = 17,
-
- /// DenseIntOrFPElementsAttr {
- /// type: ShapedType,
- /// data: blob
- /// }
- kDenseIntOrFPElementsAttr = 18,
-
- /// DenseStringElementsAttr {
- /// type: ShapedType,
- /// isSplat: varint,
- /// data: string[]
- /// }
- kDenseStringElementsAttr = 19,
-
- /// SparseElementsAttr {
- /// type: ShapedType,
- /// indices: DenseIntElementsAttr,
- /// values: DenseElementsAttr
- /// }
- kSparseElementsAttr = 20,
-};
-
-/// This enum contains marker codes used to indicate which type is currently
-/// being decoded, and how it should be decoded. The order of these codes should
-/// generally be unchanged, as any changes will inevitably break compatibility
-/// with older bytecode.
-enum TypeCode {
- /// IntegerType {
- /// widthAndSignedness: varint // (width << 2) | (signedness)
- /// }
- ///
- kIntegerType = 0,
-
- /// IndexType {
- /// }
- ///
- kIndexType = 1,
-
- /// FunctionType {
- /// inputs: Type[],
- /// results: Type[]
- /// }
- ///
- kFunctionType = 2,
-
- /// BFloat16Type {
- /// }
- ///
- kBFloat16Type = 3,
-
- /// Float16Type {
- /// }
- ///
- kFloat16Type = 4,
-
- /// Float32Type {
- /// }
- ///
- kFloat32Type = 5,
-
- /// Float64Type {
- /// }
- ///
- kFloat64Type = 6,
-
- /// Float80Type {
- /// }
- ///
- kFloat80Type = 7,
-
- /// Float128Type {
- /// }
- ///
- kFloat128Type = 8,
-
- /// ComplexType {
- /// elementType: Type
- /// }
- ///
- kComplexType = 9,
-
- /// MemRefType {
- /// shape: svarint[],
- /// elementType: Type,
- /// layout: Attribute
- /// }
- ///
- kMemRefType = 10,
-
- /// MemRefTypeWithMemSpace {
- /// memorySpace: Attribute,
- /// shape: svarint[],
- /// elementType: Type,
- /// layout: Attribute
- /// }
- /// Variant of MemRefType with non-default memory space.
- kMemRefTypeWithMemSpace = 11,
-
- /// NoneType {
- /// }
- ///
- kNoneType = 12,
-
- /// RankedTensorType {
- /// shape: svarint[],
- /// elementType: Type,
- /// }
- ///
- kRankedTensorType = 13,
-
- /// RankedTensorTypeWithEncoding {
- /// encoding: Attribute,
- /// shape: svarint[],
- /// elementType: Type
- /// }
- /// Variant of RankedTensorType with an encoding.
- kRankedTensorTypeWithEncoding = 14,
-
- /// TupleType {
- /// elementTypes: Type[]
- /// }
- kTupleType = 15,
-
- /// UnrankedMemRefType {
- /// shape: svarint[]
- /// }
- ///
- kUnrankedMemRefType = 16,
-
- /// UnrankedMemRefTypeWithMemSpace {
- /// memorySpace: Attribute,
- /// shape: svarint[]
- /// }
- /// Variant of UnrankedMemRefType with non-default memory space.
- kUnrankedMemRefTypeWithMemSpace = 17,
-
- /// UnrankedTensorType {
- /// elementType: Type
- /// }
- ///
- kUnrankedTensorType = 18,
-
- /// VectorType {
- /// shape: svarint[],
- /// elementType: Type
- /// }
- ///
- kVectorType = 19,
-
- /// VectorTypeWithScalableDims {
- /// numScalableDims: varint,
- /// shape: svarint[],
- /// elementType: Type
- /// }
- /// Variant of VectorType with scalable dimensions.
- kVectorTypeWithScalableDims = 20,
-};
-
-} // namespace builtin_encoding
-} // namespace
-
-//===----------------------------------------------------------------------===//
// BuiltinDialectBytecodeInterface
//===----------------------------------------------------------------------===//
namespace {
-/// This class implements the bytecode interface for the builtin dialect.
-struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
- BuiltinDialectBytecodeInterface(Dialect *dialect)
- : BytecodeDialectInterface(dialect) {}
-
- //===--------------------------------------------------------------------===//
- // Attributes
-
- Attribute readAttribute(DialectBytecodeReader &reader) const override;
- ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
- DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const;
- DenseElementsAttr
- readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const;
- DenseStringElementsAttr
- readDenseStringElementsAttr(DialectBytecodeReader &reader) const;
- DenseResourceElementsAttr
- readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
- DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
- FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
- IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
- SparseElementsAttr
- readSparseElementsAttr(DialectBytecodeReader &reader) const;
- StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
- SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
- bool hasNestedRefs) const;
- TypeAttr readTypeAttr(DialectBytecodeReader &reader) const;
-
- LocationAttr readCallSiteLoc(DialectBytecodeReader &reader) const;
- LocationAttr readFileLineColLoc(DialectBytecodeReader &reader) const;
- LocationAttr readFusedLoc(DialectBytecodeReader &reader,
- bool hasMetadata) const;
- LocationAttr readNameLoc(DialectBytecodeReader &reader) const;
-
- LogicalResult writeAttribute(Attribute attr,
- DialectBytecodeWriter &writer) const override;
- void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
- void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const;
- void write(DenseIntOrFPElementsAttr attr,
- DialectBytecodeWriter &writer) const;
- void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const;
- void write(DenseResourceElementsAttr attr,
- DialectBytecodeWriter &writer) const;
- void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
- void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
- void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
- void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const;
- void write(StringAttr attr, DialectBytecodeWriter &writer) const;
- void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
- void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
-
- void write(CallSiteLoc attr, DialectBytecodeWriter &writer) const;
- void write(FileLineColLoc attr, DialectBytecodeWriter &writer) const;
- void write(FusedLoc attr, DialectBytecodeWriter &writer) const;
- void write(NameLoc attr, DialectBytecodeWriter &writer) const;
- LogicalResult write(OpaqueLoc attr, DialectBytecodeWriter &writer) const;
-
- //===--------------------------------------------------------------------===//
- // Types
-
- Type readType(DialectBytecodeReader &reader) const override;
- ComplexType readComplexType(DialectBytecodeReader &reader) const;
- IntegerType readIntegerType(DialectBytecodeReader &reader) const;
- FunctionType readFunctionType(DialectBytecodeReader &reader) const;
- MemRefType readMemRefType(DialectBytecodeReader &reader,
- bool hasMemSpace) const;
- RankedTensorType readRankedTensorType(DialectBytecodeReader &reader,
- bool hasEncoding) const;
- TupleType readTupleType(DialectBytecodeReader &reader) const;
- UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader,
- bool hasMemSpace) const;
- UnrankedTensorType
- readUnrankedTensorType(DialectBytecodeReader &reader) const;
- VectorType readVectorType(DialectBytecodeReader &reader,
- bool hasScalableDims) const;
-
- LogicalResult writeType(Type type,
- DialectBytecodeWriter &writer) const override;
- void write(ComplexType type, DialectBytecodeWriter &writer) const;
- void write(IntegerType type, DialectBytecodeWriter &writer) const;
- void write(FunctionType type, DialectBytecodeWriter &writer) const;
- void write(MemRefType type, DialectBytecodeWriter &writer) const;
- void write(RankedTensorType type, DialectBytecodeWriter &writer) const;
- void write(TupleType type, DialectBytecodeWriter &writer) const;
- void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const;
- void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const;
- void write(VectorType type, DialectBytecodeWriter &writer) const;
-};
-} // namespace
-
-void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
- dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
-}
-
-//===----------------------------------------------------------------------===//
-// Attributes
-//===----------------------------------------------------------------------===//
-
-Attribute BuiltinDialectBytecodeInterface::readAttribute(
- DialectBytecodeReader &reader) const {
- uint64_t code;
- if (failed(reader.readVarInt(code)))
- return Attribute();
- switch (code) {
- case builtin_encoding::kArrayAttr:
- return readArrayAttr(reader);
- case builtin_encoding::kDictionaryAttr:
- return readDictionaryAttr(reader);
- case builtin_encoding::kStringAttr:
- return readStringAttr(reader, /*hasType=*/false);
- case builtin_encoding::kStringAttrWithType:
- return readStringAttr(reader, /*hasType=*/true);
- case builtin_encoding::kFlatSymbolRefAttr:
- return readSymbolRefAttr(reader, /*hasNestedRefs=*/false);
- case builtin_encoding::kSymbolRefAttr:
- return readSymbolRefAttr(reader, /*hasNestedRefs=*/true);
- case builtin_encoding::kTypeAttr:
- return readTypeAttr(reader);
- case builtin_encoding::kUnitAttr:
- return UnitAttr::get(getContext());
- case builtin_encoding::kIntegerAttr:
- return readIntegerAttr(reader);
- case builtin_encoding::kFloatAttr:
- return readFloatAttr(reader);
- case builtin_encoding::kCallSiteLoc:
- return readCallSiteLoc(reader);
- case builtin_encoding::kFileLineColLoc:
- return readFileLineColLoc(reader);
- case builtin_encoding::kFusedLoc:
- return readFusedLoc(reader, /*hasMetadata=*/false);
- case builtin_encoding::kFusedLocWithMetadata:
- return readFusedLoc(reader, /*hasMetadata=*/true);
- case builtin_encoding::kNameLoc:
- return readNameLoc(reader);
- case builtin_encoding::kUnknownLoc:
- return UnknownLoc::get(getContext());
- case builtin_encoding::kDenseResourceElementsAttr:
- return readDenseResourceElementsAttr(reader);
- case builtin_encoding::kDenseArrayAttr:
- return readDenseArrayAttr(reader);
- case builtin_encoding::kDenseIntOrFPElementsAttr:
- return readDenseIntOrFPElementsAttr(reader);
- case builtin_encoding::kDenseStringElementsAttr:
- return readDenseStringElementsAttr(reader);
- case builtin_encoding::kSparseElementsAttr:
- return readSparseElementsAttr(reader);
- default:
- reader.emitError() << "unknown builtin attribute code: " << code;
- return Attribute();
- }
-}
-
-LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
- Attribute attr, DialectBytecodeWriter &writer) const {
- return TypeSwitch<Attribute, LogicalResult>(attr)
- .Case<ArrayAttr, DenseArrayAttr, DenseIntOrFPElementsAttr,
- DenseStringElementsAttr, DenseResourceElementsAttr, DictionaryAttr,
- FloatAttr, IntegerAttr, SparseElementsAttr, StringAttr,
- SymbolRefAttr, TypeAttr>([&](auto attr) {
- write(attr, writer);
- return success();
- })
- .Case<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc>([&](auto attr) {
- write(attr, writer);
- return success();
- })
- .Case([&](OpaqueLoc attr) { return write(attr, writer); })
- .Case([&](UnitAttr) {
- writer.writeVarInt(builtin_encoding::kUnitAttr);
- return success();
- })
- .Case([&](UnknownLoc) {
- writer.writeVarInt(builtin_encoding::kUnknownLoc);
- return success();
- })
- .Default([&](Attribute) { return failure(); });
-}
-
-//===----------------------------------------------------------------------===//
-// ArrayAttr
-
-ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
- DialectBytecodeReader &reader) const {
- SmallVector<Attribute> elements;
- if (failed(reader.readAttributes(elements)))
- return ArrayAttr();
- return ArrayAttr::get(getContext(), elements);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- ArrayAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kArrayAttr);
- writer.writeAttributes(attr.getValue());
-}
//===----------------------------------------------------------------------===//
-// DenseArrayAttr
+// Utility functions
-DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
- DialectBytecodeReader &reader) const {
- Type elementType;
- uint64_t size;
- ArrayRef<char> blob;
- if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) ||
- failed(reader.readBlob(blob)))
- return DenseArrayAttr();
- return DenseArrayAttr::get(elementType, size, blob);
-}
+// TODO: Move these to separate file.
-void BuiltinDialectBytecodeInterface::write(
- DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
- writer.writeType(attr.getElementType());
- writer.writeVarInt(attr.getSize());
- writer.writeOwnedBlob(attr.getRawData());
+// Returns the bitwidth if known, else return 0.
+static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
+ return intType.getWidth();
+ } else if (type.isa<IndexType>()) {
+ return IndexType::kInternalStorageBitWidth;
+ }
+ reader.emitError()
+ << "expected integer or index type for IntegerAttr, but got: " << type;
+ return 0;
}
-//===----------------------------------------------------------------------===//
-// DenseIntOrFPElementsAttr
-
-DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr(
- DialectBytecodeReader &reader) const {
- ShapedType type;
- ArrayRef<char> blob;
- if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
- return DenseIntOrFPElementsAttr();
- return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob);
+static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader,
+ Type type, FailureOr<APInt> &val) {
+ unsigned bitWidth = getIntegerBitWidth(reader, type);
+ if (bitWidth == 0)
+ return failure();
+ val = reader.readAPIntWithKnownWidth(bitWidth);
+ return val;
}
-void BuiltinDialectBytecodeInterface::write(
- DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr);
- writer.writeType(attr.getType());
- writer.writeOwnedBlob(attr.getRawData());
+static LogicalResult
+readAPFloatWithKnownSemantics(DialectBytecodeReader &reader, Type type,
+ FailureOr<APFloat> &val) {
+ auto ftype = dyn_cast<FloatType>(type);
+ if (!ftype)
+ return failure();
+ val = reader.readAPFloatWithKnownSemantics(ftype.getFloatSemantics());
+ return success();
}
-//===----------------------------------------------------------------------===//
-// DenseStringElementsAttr
-
-DenseStringElementsAttr
-BuiltinDialectBytecodeInterface::readDenseStringElementsAttr(
- DialectBytecodeReader &reader) const {
- ShapedType type;
- uint64_t isSplat;
- if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat)))
- return DenseStringElementsAttr();
-
- SmallVector<StringRef> values(isSplat ? 1 : type.getNumElements());
- for (StringRef &value : values)
+LogicalResult
+readPotentiallySplatString(DialectBytecodeReader &reader, ShapedType type,
+ bool isSplat,
+ SmallVectorImpl<StringRef> &rawStringData) {
+ rawStringData.resize(isSplat ? 1 : type.getNumElements());
+ for (StringRef &value : rawStringData)
if (failed(reader.readString(value)))
- return DenseStringElementsAttr();
- return DenseStringElementsAttr::get(type, values);
+ return failure();
+ return success();
}
-void BuiltinDialectBytecodeInterface::write(
- DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr);
- writer.writeType(attr.getType());
-
+void writePotentiallySplatString(DialectBytecodeWriter &writer,
+ DenseStringElementsAttr attr) {
bool isSplat = attr.isSplat();
- writer.writeVarInt(isSplat);
-
- // If the attribute is a splat, only write out the single value.
if (isSplat)
return writer.writeOwnedString(attr.getRawStringData().front());
@@ -564,614 +80,39 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeOwnedString(str);
}
-//===----------------------------------------------------------------------===//
-// DenseResourceElementsAttr
-
-DenseResourceElementsAttr
-BuiltinDialectBytecodeInterface::readDenseResourceElementsAttr(
- DialectBytecodeReader &reader) const {
- ShapedType type;
- if (failed(reader.readType(type)))
- return DenseResourceElementsAttr();
-
- FailureOr<DenseResourceElementsHandle> handle =
- reader.readResourceHandle<DenseResourceElementsHandle>();
- if (failed(handle))
- return DenseResourceElementsAttr();
-
- return DenseResourceElementsAttr::get(type, *handle);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kDenseResourceElementsAttr);
- writer.writeType(attr.getType());
- writer.writeResourceHandle(attr.getRawHandle());
-}
-
-//===----------------------------------------------------------------------===//
-// DictionaryAttr
-
-DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
- DialectBytecodeReader &reader) const {
- auto readNamedAttr = [&]() -> FailureOr<NamedAttribute> {
- StringAttr name;
- Attribute value;
- if (failed(reader.readAttribute(name)) ||
- failed(reader.readAttribute(value)))
- return failure();
- return NamedAttribute(name, value);
- };
- SmallVector<NamedAttribute> attrs;
- if (failed(reader.readList(attrs, readNamedAttr)))
- return DictionaryAttr();
- return DictionaryAttr::get(getContext(), attrs);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- DictionaryAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kDictionaryAttr);
- writer.writeList(attr.getValue(), [&](NamedAttribute attr) {
- writer.writeAttribute(attr.getName());
- writer.writeAttribute(attr.getValue());
- });
-}
-
-//===----------------------------------------------------------------------===//
-// FloatAttr
-
-FloatAttr BuiltinDialectBytecodeInterface::readFloatAttr(
- DialectBytecodeReader &reader) const {
- FloatType type;
- if (failed(reader.readType(type)))
- return FloatAttr();
- FailureOr<APFloat> value =
- reader.readAPFloatWithKnownSemantics(type.getFloatSemantics());
- if (failed(value))
- return FloatAttr();
- return FloatAttr::get(type, *value);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- FloatAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kFloatAttr);
- writer.writeType(attr.getType());
- writer.writeAPFloatWithKnownSemantics(attr.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// IntegerAttr
-
-IntegerAttr BuiltinDialectBytecodeInterface::readIntegerAttr(
- DialectBytecodeReader &reader) const {
- Type type;
- if (failed(reader.readType(type)))
- return IntegerAttr();
-
- // Extract the value storage width from the type.
- unsigned bitWidth;
- if (auto intType = type.dyn_cast<IntegerType>()) {
- bitWidth = intType.getWidth();
- } else if (type.isa<IndexType>()) {
- bitWidth = IndexType::kInternalStorageBitWidth;
- } else {
- reader.emitError()
- << "expected integer or index type for IntegerAttr, but got: " << type;
- return IntegerAttr();
- }
-
- FailureOr<APInt> value = reader.readAPIntWithKnownWidth(bitWidth);
- if (failed(value))
- return IntegerAttr();
- return IntegerAttr::get(type, *value);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- IntegerAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kIntegerAttr);
- writer.writeType(attr.getType());
- writer.writeAPIntWithKnownWidth(attr.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// SparseElementsAttr
-
-SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr(
- DialectBytecodeReader &reader) const {
- ShapedType type;
- DenseIntElementsAttr indices;
- DenseElementsAttr values;
- if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) ||
- failed(reader.readAttribute(values)))
- return SparseElementsAttr();
- return SparseElementsAttr::get(type, indices, values);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- SparseElementsAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kSparseElementsAttr);
- writer.writeType(attr.getType());
- writer.writeAttribute(attr.getIndices());
- writer.writeAttribute(attr.getValues());
-}
-
-//===----------------------------------------------------------------------===//
-// StringAttr
-
-StringAttr
-BuiltinDialectBytecodeInterface::readStringAttr(DialectBytecodeReader &reader,
- bool hasType) const {
- StringRef string;
- if (failed(reader.readString(string)))
- return StringAttr();
-
- // Read the type if present.
- Type type;
- if (!hasType)
- type = NoneType::get(getContext());
- else if (failed(reader.readType(type)))
- return StringAttr();
- return StringAttr::get(string, type);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- StringAttr attr, DialectBytecodeWriter &writer) const {
- // We only encode the type if it isn't NoneType, which is significantly less
- // common.
- Type type = attr.getType();
- if (!type.isa<NoneType>()) {
- writer.writeVarInt(builtin_encoding::kStringAttrWithType);
- writer.writeOwnedString(attr.getValue());
- writer.writeType(type);
- return;
- }
- writer.writeVarInt(builtin_encoding::kStringAttr);
- writer.writeOwnedString(attr.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// SymbolRefAttr
+#include "mlir/IR/BuiltinDialectBytecode.cpp.inc"
-SymbolRefAttr BuiltinDialectBytecodeInterface::readSymbolRefAttr(
- DialectBytecodeReader &reader, bool hasNestedRefs) const {
- StringAttr rootReference;
- if (failed(reader.readAttribute(rootReference)))
- return SymbolRefAttr();
- SmallVector<FlatSymbolRefAttr> nestedReferences;
- if (hasNestedRefs && failed(reader.readAttributes(nestedReferences)))
- return SymbolRefAttr();
- return SymbolRefAttr::get(rootReference, nestedReferences);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- SymbolRefAttr attr, DialectBytecodeWriter &writer) const {
- ArrayRef<FlatSymbolRefAttr> nestedRefs = attr.getNestedReferences();
- writer.writeVarInt(nestedRefs.empty() ? builtin_encoding::kFlatSymbolRefAttr
- : builtin_encoding::kSymbolRefAttr);
-
- writer.writeAttribute(attr.getRootReference());
- if (!nestedRefs.empty())
- writer.writeAttributes(nestedRefs);
-}
-
-//===----------------------------------------------------------------------===//
-// TypeAttr
-
-TypeAttr BuiltinDialectBytecodeInterface::readTypeAttr(
- DialectBytecodeReader &reader) const {
- Type type;
- if (failed(reader.readType(type)))
- return TypeAttr();
- return TypeAttr::get(type);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- TypeAttr attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kTypeAttr);
- writer.writeType(attr.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// CallSiteLoc
-
-LocationAttr BuiltinDialectBytecodeInterface::readCallSiteLoc(
- DialectBytecodeReader &reader) const {
- LocationAttr callee, caller;
- if (failed(reader.readAttribute(callee)) ||
- failed(reader.readAttribute(caller)))
- return LocationAttr();
- return CallSiteLoc::get(callee, caller);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- CallSiteLoc attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kCallSiteLoc);
- writer.writeAttribute(attr.getCallee());
- writer.writeAttribute(attr.getCaller());
-}
-
-//===----------------------------------------------------------------------===//
-// FileLineColLoc
-
-LocationAttr BuiltinDialectBytecodeInterface::readFileLineColLoc(
- DialectBytecodeReader &reader) const {
- StringAttr filename;
- uint64_t line, column;
- if (failed(reader.readAttribute(filename)) ||
- failed(reader.readVarInt(line)) || failed(reader.readVarInt(column)))
- return LocationAttr();
- return FileLineColLoc::get(filename, line, column);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- FileLineColLoc attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kFileLineColLoc);
- writer.writeAttribute(attr.getFilename());
- writer.writeVarInt(attr.getLine());
- writer.writeVarInt(attr.getColumn());
-}
-
-//===----------------------------------------------------------------------===//
-// FusedLoc
-
-LocationAttr
-BuiltinDialectBytecodeInterface::readFusedLoc(DialectBytecodeReader &reader,
- bool hasMetadata) const {
- // Parse the child locations.
- auto readLoc = [&]() -> FailureOr<Location> {
- LocationAttr locAttr;
- if (failed(reader.readAttribute(locAttr)))
- return failure();
- return Location(locAttr);
- };
- SmallVector<Location> locations;
- if (failed(reader.readList(locations, readLoc)))
- return LocationAttr();
-
- // Parse the metadata if present.
- Attribute metadata;
- if (hasMetadata && failed(reader.readAttribute(metadata)))
- return LocationAttr();
-
- return FusedLoc::get(locations, metadata, getContext());
-}
-
-void BuiltinDialectBytecodeInterface::write(
- FusedLoc attr, DialectBytecodeWriter &writer) const {
- if (Attribute metadata = attr.getMetadata()) {
- writer.writeVarInt(builtin_encoding::kFusedLocWithMetadata);
- writer.writeAttributes(attr.getLocations());
- writer.writeAttribute(metadata);
- } else {
- writer.writeVarInt(builtin_encoding::kFusedLoc);
- writer.writeAttributes(attr.getLocations());
- }
-}
-
-//===----------------------------------------------------------------------===//
-// NameLoc
-
-LocationAttr BuiltinDialectBytecodeInterface::readNameLoc(
- DialectBytecodeReader &reader) const {
- StringAttr name;
- LocationAttr childLoc;
- if (failed(reader.readAttribute(name)) ||
- failed(reader.readAttribute(childLoc)))
- return LocationAttr();
- return NameLoc::get(name, childLoc);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- NameLoc attr, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kNameLoc);
- writer.writeAttribute(attr.getName());
- writer.writeAttribute(attr.getChildLoc());
-}
-
-//===----------------------------------------------------------------------===//
-// OpaqueLoc
-
-LogicalResult
-BuiltinDialectBytecodeInterface::write(OpaqueLoc attr,
- DialectBytecodeWriter &writer) const {
- // We can't encode an OpaqueLoc directly given that it is in-memory only, so
- // encode the fallback instead.
- return writeAttribute(attr.getFallbackLocation(), writer);
-}
-
-//===----------------------------------------------------------------------===//
-// Types
-//===----------------------------------------------------------------------===//
+/// This class implements the bytecode interface for the builtin dialect.
+struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
+ BuiltinDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
-Type BuiltinDialectBytecodeInterface::readType(
- DialectBytecodeReader &reader) const {
- uint64_t code;
- if (failed(reader.readVarInt(code)))
- return Type();
- switch (code) {
- case builtin_encoding::kIntegerType:
- return readIntegerType(reader);
- case builtin_encoding::kIndexType:
- return IndexType::get(getContext());
- case builtin_encoding::kFunctionType:
- return readFunctionType(reader);
- case builtin_encoding::kBFloat16Type:
- return BFloat16Type::get(getContext());
- case builtin_encoding::kFloat16Type:
- return Float16Type::get(getContext());
- case builtin_encoding::kFloat32Type:
- return Float32Type::get(getContext());
- case builtin_encoding::kFloat64Type:
- return Float64Type::get(getContext());
- case builtin_encoding::kFloat80Type:
- return Float80Type::get(getContext());
- case builtin_encoding::kFloat128Type:
- return Float128Type::get(getContext());
- case builtin_encoding::kComplexType:
- return readComplexType(reader);
- case builtin_encoding::kMemRefType:
- return readMemRefType(reader, /*hasMemSpace=*/false);
- case builtin_encoding::kMemRefTypeWithMemSpace:
- return readMemRefType(reader, /*hasMemSpace=*/true);
- case builtin_encoding::kNoneType:
- return NoneType::get(getContext());
- case builtin_encoding::kRankedTensorType:
- return readRankedTensorType(reader, /*hasEncoding=*/false);
- case builtin_encoding::kRankedTensorTypeWithEncoding:
- return readRankedTensorType(reader, /*hasEncoding=*/true);
- case builtin_encoding::kTupleType:
- return readTupleType(reader);
- case builtin_encoding::kUnrankedMemRefType:
- return readUnrankedMemRefType(reader, /*hasMemSpace=*/false);
- case builtin_encoding::kUnrankedMemRefTypeWithMemSpace:
- return readUnrankedMemRefType(reader, /*hasMemSpace=*/true);
- case builtin_encoding::kUnrankedTensorType:
- return readUnrankedTensorType(reader);
- case builtin_encoding::kVectorType:
- return readVectorType(reader, /*hasScalableDims=*/false);
- case builtin_encoding::kVectorTypeWithScalableDims:
- return readVectorType(reader, /*hasScalableDims=*/true);
+ //===--------------------------------------------------------------------===//
+ // Attributes
- default:
- reader.emitError() << "unknown builtin type code: " << code;
- return Type();
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
}
-}
-
-LogicalResult BuiltinDialectBytecodeInterface::writeType(
- Type type, DialectBytecodeWriter &writer) const {
- return TypeSwitch<Type, LogicalResult>(type)
- .Case<ComplexType, IntegerType, FunctionType, MemRefType,
- RankedTensorType, TupleType, UnrankedMemRefType, UnrankedTensorType,
- VectorType>([&](auto type) {
- write(type, writer);
- return success();
- })
- .Case([&](IndexType) {
- return writer.writeVarInt(builtin_encoding::kIndexType), success();
- })
- .Case([&](BFloat16Type) {
- return writer.writeVarInt(builtin_encoding::kBFloat16Type), success();
- })
- .Case([&](Float16Type) {
- return writer.writeVarInt(builtin_encoding::kFloat16Type), success();
- })
- .Case([&](Float32Type) {
- return writer.writeVarInt(builtin_encoding::kFloat32Type), success();
- })
- .Case([&](Float64Type) {
- return writer.writeVarInt(builtin_encoding::kFloat64Type), success();
- })
- .Case([&](Float80Type) {
- return writer.writeVarInt(builtin_encoding::kFloat80Type), success();
- })
- .Case([&](Float128Type) {
- return writer.writeVarInt(builtin_encoding::kFloat128Type), success();
- })
- .Case([&](NoneType) {
- return writer.writeVarInt(builtin_encoding::kNoneType), success();
- })
- .Default([&](Type) { return failure(); });
-}
-
-//===----------------------------------------------------------------------===//
-// ComplexType
-
-ComplexType BuiltinDialectBytecodeInterface::readComplexType(
- DialectBytecodeReader &reader) const {
- Type elementType;
- if (failed(reader.readType(elementType)))
- return ComplexType();
- return ComplexType::get(elementType);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- ComplexType type, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kComplexType);
- writer.writeType(type.getElementType());
-}
-
-//===----------------------------------------------------------------------===//
-// IntegerType
-
-IntegerType BuiltinDialectBytecodeInterface::readIntegerType(
- DialectBytecodeReader &reader) const {
- uint64_t encoding;
- if (failed(reader.readVarInt(encoding)))
- return IntegerType();
- return IntegerType::get(
- getContext(), encoding >> 2,
- static_cast<IntegerType::SignednessSemantics>(encoding & 0x3));
-}
-
-void BuiltinDialectBytecodeInterface::write(
- IntegerType type, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kIntegerType);
- writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
-}
-
-//===----------------------------------------------------------------------===//
-// FunctionType
-
-FunctionType BuiltinDialectBytecodeInterface::readFunctionType(
- DialectBytecodeReader &reader) const {
- SmallVector<Type> inputs, results;
- if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results)))
- return FunctionType();
- return FunctionType::get(getContext(), inputs, results);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- FunctionType type, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kFunctionType);
- writer.writeTypes(type.getInputs());
- writer.writeTypes(type.getResults());
-}
-
-//===----------------------------------------------------------------------===//
-// MemRefType
-
-MemRefType
-BuiltinDialectBytecodeInterface::readMemRefType(DialectBytecodeReader &reader,
- bool hasMemSpace) const {
- Attribute memorySpace;
- if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
- return MemRefType();
- SmallVector<int64_t> shape;
- Type elementType;
- MemRefLayoutAttrInterface layout;
- if (failed(reader.readSignedVarInts(shape)) ||
- failed(reader.readType(elementType)) ||
- failed(reader.readAttribute(layout)))
- return MemRefType();
- return MemRefType::get(shape, elementType, layout, memorySpace);
-}
-void BuiltinDialectBytecodeInterface::write(
- MemRefType type, DialectBytecodeWriter &writer) const {
- if (Attribute memSpace = type.getMemorySpace()) {
- writer.writeVarInt(builtin_encoding::kMemRefTypeWithMemSpace);
- writer.writeAttribute(memSpace);
- } else {
- writer.writeVarInt(builtin_encoding::kMemRefType);
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
}
- writer.writeSignedVarInts(type.getShape());
- writer.writeType(type.getElementType());
- writer.writeAttribute(type.getLayout());
-}
-
-//===----------------------------------------------------------------------===//
-// RankedTensorType
-RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType(
- DialectBytecodeReader &reader, bool hasEncoding) const {
- Attribute encoding;
- if (hasEncoding && failed(reader.readAttribute(encoding)))
- return RankedTensorType();
- SmallVector<int64_t> shape;
- Type elementType;
- if (failed(reader.readSignedVarInts(shape)) ||
- failed(reader.readType(elementType)))
- return RankedTensorType();
- return RankedTensorType::get(shape, elementType, encoding);
-}
+ //===--------------------------------------------------------------------===//
+ // Types
-void BuiltinDialectBytecodeInterface::write(
- RankedTensorType type, DialectBytecodeWriter &writer) const {
- if (Attribute encoding = type.getEncoding()) {
- writer.writeVarInt(builtin_encoding::kRankedTensorTypeWithEncoding);
- writer.writeAttribute(encoding);
- } else {
- writer.writeVarInt(builtin_encoding::kRankedTensorType);
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
}
- writer.writeSignedVarInts(type.getShape());
- writer.writeType(type.getElementType());
-}
-
-//===----------------------------------------------------------------------===//
-// TupleType
-
-TupleType BuiltinDialectBytecodeInterface::readTupleType(
- DialectBytecodeReader &reader) const {
- SmallVector<Type> elements;
- if (failed(reader.readTypes(elements)))
- return TupleType();
- return TupleType::get(getContext(), elements);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- TupleType type, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kTupleType);
- writer.writeTypes(type.getTypes());
-}
-
-//===----------------------------------------------------------------------===//
-// UnrankedMemRefType
-
-UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType(
- DialectBytecodeReader &reader, bool hasMemSpace) const {
- Attribute memorySpace;
- if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
- return UnrankedMemRefType();
- Type elementType;
- if (failed(reader.readType(elementType)))
- return UnrankedMemRefType();
- return UnrankedMemRefType::get(elementType, memorySpace);
-}
-void BuiltinDialectBytecodeInterface::write(
- UnrankedMemRefType type, DialectBytecodeWriter &writer) const {
- if (Attribute memSpace = type.getMemorySpace()) {
- writer.writeVarInt(builtin_encoding::kUnrankedMemRefTypeWithMemSpace);
- writer.writeAttribute(memSpace);
- } else {
- writer.writeVarInt(builtin_encoding::kUnrankedMemRefType);
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
}
- writer.writeType(type.getElementType());
-}
-
-//===----------------------------------------------------------------------===//
-// UnrankedTensorType
-
-UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType(
- DialectBytecodeReader &reader) const {
- Type elementType;
- if (failed(reader.readType(elementType)))
- return UnrankedTensorType();
- return UnrankedTensorType::get(elementType);
-}
-
-void BuiltinDialectBytecodeInterface::write(
- UnrankedTensorType type, DialectBytecodeWriter &writer) const {
- writer.writeVarInt(builtin_encoding::kUnrankedTensorType);
- writer.writeType(type.getElementType());
-}
-
-//===----------------------------------------------------------------------===//
-// VectorType
-
-VectorType
-BuiltinDialectBytecodeInterface::readVectorType(DialectBytecodeReader &reader,
- bool hasScalableDims) const {
- uint64_t numScalableDims = 0;
- if (hasScalableDims && failed(reader.readVarInt(numScalableDims)))
- return VectorType();
- SmallVector<int64_t> shape;
- Type elementType;
- if (failed(reader.readSignedVarInts(shape)) ||
- failed(reader.readType(elementType)))
- return VectorType();
- return VectorType::get(shape, elementType, numScalableDims);
-}
+};
+} // namespace
-void BuiltinDialectBytecodeInterface::write(
- VectorType type, DialectBytecodeWriter &writer) const {
- if (unsigned numScalableDims = type.getNumScalableDims()) {
- writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims);
- writer.writeVarInt(numScalableDims);
- } else {
- writer.writeVarInt(builtin_encoding::kVectorType);
- }
- writer.writeSignedVarInts(type.getShape());
- writer.writeType(type.getElementType());
+void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
+ dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 4377ebe16055..b729282e627d 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -44,6 +44,7 @@ add_mlir_library(MLIRIR
DEPENDS
MLIRBuiltinAttributesIncGen
MLIRBuiltinAttributeInterfacesIncGen
+ MLIRBuiltinDialectBytecodeIncGen
MLIRBuiltinDialectIncGen
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
new file mode 100644
index 000000000000..f13bdd49413b
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -0,0 +1,467 @@
+//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include <regex>
+
+using namespace llvm;
+
+static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
+static llvm::cl::opt<std::string>
+ selectedBcDialect("bytecode-dialect",
+ llvm::cl::desc("The dialect to gen for"),
+ llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
+
+namespace {
+
+/// Helper class to generate C++ bytecode parser helpers.
+class Generator {
+public:
+ Generator(raw_ostream &output) : output(output) {}
+
+ /// Returns whether successfully emitted attribute/type parsers.
+ void emitParse(StringRef kind, Record &x);
+
+ /// Returns whether successfully emitted attribute/type printers.
+ void emitPrint(StringRef kind, StringRef type,
+ ArrayRef<std::pair<int64_t, Record *>> vec);
+
+ /// Emits parse dispatch table.
+ void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
+
+ /// Emits print dispatch table.
+ void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
+
+private:
+ /// Emits parse calls to construct given kind.
+ void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
+ ArrayRef<Init *> args, ArrayRef<std::string> argNames,
+ StringRef failure, mlir::raw_indented_ostream &ios);
+
+ /// Emits print instructions.
+ void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
+ StringRef name, mlir::raw_indented_ostream &ios);
+
+ raw_ostream &output;
+};
+} // namespace
+
+/// Helper to replace set of from strings to target in `s`.
+/// Assumed: non-overlapping replacements.
+static std::string format(StringRef templ,
+ std::map<std::string, std::string> &&map) {
+ std::string s = templ.str();
+ for (const auto &[from, to] : map)
+ // All replacements start with $, don't treat as anchor.
+ s = std::regex_replace(s, std::regex("\\" + from), to);
+ return s;
+}
+
+/// Return string with first character capitalized.
+static std::string capitalize(StringRef str) {
+ return ((Twine)toUpper(str[0]) + str.drop_front()).str();
+}
+
+/// Return the C++ type for the given record.
+static std::string getCType(Record *def) {
+ std::string format = "{0}";
+ if (def->isSubClassOf("Array")) {
+ def = def->getValueAsDef("elemT");
+ format = "SmallVector<{0}>";
+ }
+
+ StringRef cType = def->getValueAsString("cType");
+ if (cType.empty()) {
+ if (def->isAnonymous())
+ PrintFatalError(def->getLoc(), "Unable to determine cType");
+
+ return formatv(format.c_str(), def->getName().str());
+ }
+ return formatv(format.c_str(), cType.str());
+}
+
+void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
+ mlir::raw_indented_ostream os(output);
+ char const *head =
+ R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
+ os << formatv(head, capitalize(kind));
+ auto funScope = os.scope(" {\n", "}\n\n");
+
+ os << "uint64_t kind;\n";
+ os << "if (failed(reader.readVarInt(kind)))\n"
+ << " return " << capitalize(kind) << "();\n";
+ os << "switch (kind) ";
+ {
+ auto switchScope = os.scope("{\n", "}\n");
+ for (const auto &it : llvm::enumerate(vec)) {
+ os << formatv("case {1}:\n return read{0}(context, reader);\n",
+ it.value()->getName(), it.index());
+ }
+ os << "default:\n"
+ << " reader.emitError() << \"unknown attribute code: \" "
+ << "<< kind;\n"
+ << " return " << capitalize(kind) << "();\n";
+ }
+ os << "return " << capitalize(kind) << "();\n";
+}
+
+void Generator::emitParse(StringRef kind, Record &x) {
+ char const *head =
+ R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
+ mlir::raw_indented_ostream os(output);
+ std::string returnType = getCType(&x);
+ os << formatv(head, returnType, x.getName());
+ DagInit *members = x.getValueAsDag("members");
+ SmallVector<std::string> argNames =
+ llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
+ return init->getAsUnquotedString();
+ }));
+ StringRef builder = x.getValueAsString("cBuilder");
+ emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
+ returnType + "()", os);
+ os << "\n\n";
+}
+
+void printParseConditional(mlir::raw_indented_ostream &ios,
+ ArrayRef<Init *> args,
+ ArrayRef<std::string> argNames) {
+ ios << "if ";
+ auto parenScope = ios.scope("(", ") {");
+ ios.indent();
+
+ auto listHelperName = [](StringRef name) {
+ return formatv("read{0}", capitalize(name));
+ };
+
+ auto parsedArgs =
+ llvm::to_vector(make_filter_range(args, [](Init *const attr) {
+ Record *def = cast<DefInit>(attr)->getDef();
+ if (def->isSubClassOf("Array"))
+ return true;
+ return !def->getValueAsString("cParser").empty();
+ }));
+
+ interleave(
+ zip(parsedArgs, argNames),
+ [&](std::tuple<llvm::Init *&, const std::string &> it) {
+ Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
+ std::string parser;
+ if (auto optParser = attr->getValueAsOptionalString("cParser")) {
+ parser = *optParser;
+ } else if (attr->isSubClassOf("Array")) {
+ Record *def = attr->getValueAsDef("elemT");
+ bool composite = def->isSubClassOf("CompositeBytecode");
+ if (!composite && def->isSubClassOf("AttributeKind"))
+ parser = "succeeded($_reader.readAttributes($_var))";
+ else if (!composite && def->isSubClassOf("TypeKind"))
+ parser = "succeeded($_reader.readTypes($_var))";
+ else
+ parser = ("succeeded($_reader.readList($_var, " +
+ listHelperName(std::get<1>(it)) + "))")
+ .str();
+ } else {
+ PrintFatalError(attr->getLoc(), "No parser specified");
+ }
+ std::string type = getCType(attr);
+ ios << format(parser, {{"$_reader", "reader"},
+ {"$_resultType", type},
+ {"$_var", std::get<1>(it)}});
+ },
+ [&]() { ios << " &&\n"; });
+}
+
+void Generator::emitParseHelper(StringRef kind, StringRef returnType,
+ StringRef builder, ArrayRef<Init *> args,
+ ArrayRef<std::string> argNames,
+ StringRef failure,
+ mlir::raw_indented_ostream &ios) {
+ auto funScope = ios.scope("{\n", "}");
+
+ if (args.empty()) {
+ ios << formatv("return get<{0}>(context);\n", returnType);
+ return;
+ }
+
+ // Print decls.
+ std::string lastCType = "";
+ for (auto [arg, name] : zip(args, argNames)) {
+ DefInit *first = dyn_cast<DefInit>(arg);
+ if (!first)
+ PrintFatalError("Unexpected type for " + name);
+ Record *def = first->getDef();
+
+ // Create variable decls, if there are a block of same type then create
+ // comma separated list of them.
+ std::string cType = getCType(def);
+ if (lastCType == cType) {
+ ios << ", ";
+ } else {
+ if (!lastCType.empty())
+ ios << ";\n";
+ ios << cType << " ";
+ }
+ ios << name;
+ lastCType = cType;
+ }
+ ios << ";\n";
+
+ // Returns the name of the helper used in list parsing. E.g., the name of the
+ // lambda passed to array parsing.
+ auto listHelperName = [](StringRef name) {
+ return formatv("read{0}", capitalize(name));
+ };
+
+ // Emit list helper functions.
+ for (auto [arg, name] : zip(args, argNames)) {
+ Record *attr = cast<DefInit>(arg)->getDef();
+ if (!attr->isSubClassOf("Array"))
+ continue;
+
+ // TODO: Dedupe readers.
+ Record *def = attr->getValueAsDef("elemT");
+ if (!def->isSubClassOf("CompositeBytecode") &&
+ (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
+ continue;
+
+ std::string returnType = getCType(def);
+ ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
+ << returnType << "> ";
+ SmallVector<Init *> args;
+ SmallVector<std::string> argNames;
+ if (def->isSubClassOf("CompositeBytecode")) {
+ DagInit *members = def->getValueAsDag("members");
+ args = llvm::to_vector(members->getArgs());
+ argNames = llvm::to_vector(
+ map_range(members->getArgNames(), [](StringInit *init) {
+ return init->getAsUnquotedString();
+ }));
+ } else {
+ args = {def->getDefInit()};
+ argNames = {"temp"};
+ }
+ StringRef builder = def->getValueAsString("cBuilder");
+ emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
+ ios);
+ ios << ";\n";
+ }
+
+ // Print parse conditional.
+ printParseConditional(ios, args, argNames);
+
+ // Compute args to pass to create method.
+ auto passedArgs = llvm::to_vector(make_filter_range(
+ argNames, [](StringRef str) { return !str.starts_with("_"); }));
+ std::string argStr;
+ raw_string_ostream argStream(argStr);
+ interleaveComma(passedArgs, argStream,
+ [&](const std::string &str) { argStream << str; });
+ // Return the invoked constructor.
+ ios << "\nreturn "
+ << format(builder, {{"$_resultType", returnType.str()},
+ {"$_args", argStream.str()}})
+ << ";\n";
+ ios.unindent();
+
+ // TODO: Emit error in debug.
+ // This assumes the result types in error case can always be empty
+ // constructed.
+ ios << "}\nreturn " << failure << ";\n";
+}
+
+void Generator::emitPrint(StringRef kind, StringRef type,
+ ArrayRef<std::pair<int64_t, Record *>> vec) {
+ char const *head =
+ R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
+ mlir::raw_indented_ostream os(output);
+ os << formatv(head, type, kind);
+ auto funScope = os.scope("{\n", "}\n\n");
+
+ // Check that predicates specified if multiple bytecode instances.
+ for (llvm::Record *rec : make_second_range(vec)) {
+ StringRef pred = rec->getValueAsString("printerPredicate");
+ if (vec.size() > 1 && pred.empty()) {
+ for (auto [index, rec] : vec) {
+ (void)index;
+ StringRef pred = rec->getValueAsString("printerPredicate");
+ if (vec.size() > 1 && pred.empty())
+ PrintError(rec->getLoc(),
+ "Requires parsing predicate given common cType");
+ }
+ PrintFatalError("Unspecified for shared cType " + type);
+ }
+ }
+
+ for (auto [index, rec] : vec) {
+ StringRef pred = rec->getValueAsString("printerPredicate");
+ if (!pred.empty()) {
+ os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
+ os.indent();
+ }
+
+ os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
+ << ");\n";
+
+ auto *members = rec->getValueAsDag("members");
+ for (auto [arg, name] :
+ llvm::zip(members->getArgs(), members->getArgNames())) {
+ DefInit *def = dyn_cast<DefInit>(arg);
+ assert(def);
+ Record *memberRec = def->getDef();
+ emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
+ }
+
+ if (!pred.empty()) {
+ os.unindent();
+ os << "}\n";
+ }
+ }
+}
+
+void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
+ StringRef parent, StringRef name,
+ mlir::raw_indented_ostream &ios) {
+ std::string getter;
+ if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
+ cGetter && !cGetter->empty()) {
+ getter = format(
+ *cGetter,
+ {{"$_attrType", parent.str()},
+ {"$_member", name.str()},
+ {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
+ } else {
+ getter =
+ formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
+ .str();
+ }
+
+ if (memberRec->isSubClassOf("Array")) {
+ Record *def = memberRec->getValueAsDef("elemT");
+ if (!def->isSubClassOf("CompositeBytecode")) {
+ if (def->isSubClassOf("AttributeKind")) {
+ ios << "writer.writeAttributes(" << getter << ");\n";
+ return;
+ }
+ if (def->isSubClassOf("TypeKind")) {
+ ios << "writer.writeTypes(" << getter << ");\n";
+ return;
+ }
+ }
+ std::string returnType = getCType(def);
+ ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
+ << kind << ") ";
+ auto lambdaScope = ios.scope("{\n", "});\n");
+ return emitPrintHelper(def, kind, kind, kind, ios);
+ }
+ if (memberRec->isSubClassOf("CompositeBytecode")) {
+ auto *members = memberRec->getValueAsDag("members");
+ for (auto [arg, argName] :
+ zip(members->getArgs(), members->getArgNames())) {
+ DefInit *def = dyn_cast<DefInit>(arg);
+ assert(def);
+ emitPrintHelper(def->getDef(), kind, parent,
+ argName->getAsUnquotedString(), ios);
+ }
+ }
+
+ if (std::string printer = memberRec->getValueAsString("cPrinter").str();
+ !printer.empty())
+ ios << format(printer, {{"$_writer", "writer"},
+ {"$_name", kind.str()},
+ {"$_getter", getter}})
+ << ";\n";
+}
+
+void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
+ mlir::raw_indented_ostream os(output);
+ char const *head = R"(static LogicalResult write{0}({0} {1},
+ DialectBytecodeWriter &writer))";
+ os << formatv(head, capitalize(kind), kind);
+ auto funScope = os.scope(" {\n", "}\n\n");
+
+ os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
+ << ")";
+ auto switchScope = os.scope("", "");
+ for (StringRef type : vec) {
+ os << "\n.Case([&](" << type << " t)";
+ auto caseScope = os.scope(" {\n", "})");
+ os << "return write(t, writer), success();\n";
+ }
+ os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
+}
+
+namespace {
+/// Container of Attribute or Type for Dialect.
+struct AttrOrType {
+ std::vector<Record *> attr, type;
+};
+} // namespace
+
+static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
+ MapVector<StringRef, AttrOrType> dialectAttrOrType;
+ for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
+ if (!selectedBcDialect.empty() &&
+ it->getValueAsString("dialect") != selectedBcDialect)
+ continue;
+ dialectAttrOrType[it->getValueAsString("dialect")].attr =
+ it->getValueAsListOfDefs("elems");
+ }
+ for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
+ if (!selectedBcDialect.empty() &&
+ it->getValueAsString("dialect") != selectedBcDialect)
+ continue;
+ dialectAttrOrType[it->getValueAsString("dialect")].type =
+ it->getValueAsListOfDefs("elems");
+ }
+
+ if (dialectAttrOrType.size() != 1)
+ PrintFatalError("Single dialect per invocation required (either only "
+ "one in input file or specified via dialect option)");
+
+ auto it = dialectAttrOrType.front();
+ Generator gen(os);
+
+ SmallVector<std::vector<Record *> *, 2> vecs;
+ SmallVector<std::string, 2> kinds;
+ vecs.push_back(&it.second.attr);
+ kinds.push_back("attribute");
+ vecs.push_back(&it.second.type);
+ kinds.push_back("type");
+ for (auto [vec, kind] : zip(vecs, kinds)) {
+ // Handle Attribute/Type emission.
+ std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
+ for (auto kt : llvm::enumerate(*vec))
+ perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
+ for (const auto &jt : perType) {
+ for (auto kt : jt.second)
+ gen.emitParse(kind, *std::get<1>(kt));
+ gen.emitPrint(kind, jt.first, jt.second);
+ }
+ gen.emitParseDispatch(kind, *vec);
+
+ SmallVector<std::string> types;
+ for (const auto &it : perType) {
+ types.push_back(it.first);
+ }
+ gen.emitPrintDispatch(kind, types);
+ }
+
+ return false;
+}
+
+static mlir::GenRegistration
+ genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitBCRW(records, os);
+ });
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index fe217450d6fb..0835b6d27c71 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR
EXPORT MLIR
AttrOrTypeDefGen.cpp
AttrOrTypeFormatGen.cpp
+ BytecodeDialectGen.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 881a12426baf..fc189556b864 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -91,6 +91,7 @@ td_library(
],
includes = ["include"],
deps = [
+ ":BytecodeTdFiles",
":CallInterfacesTdFiles",
":CastInterfacesTdFiles",
":DataLayoutInterfacesTdFiles",
@@ -118,6 +119,20 @@ gentbl_cc_library(
)
gentbl_cc_library(
+ name = "BuiltinDialectBytecodeGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-bytecode", "-bytecode-dialect=Builtin"],
+ "include/mlir/IR/BuiltinDialectBytecode.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/IR/BuiltinDialectBytecode.td",
+ deps = [":BuiltinDialectTdFiles"],
+)
+
+gentbl_cc_library(
name = "BuiltinAttributesIncGen",
strip_include_prefix = "include",
tbl_outs = [
@@ -277,6 +292,7 @@ cc_library(
deps = [
":BuiltinAttributeInterfacesIncGen",
":BuiltinAttributesIncGen",
+ ":BuiltinDialectBytecodeGen",
":BuiltinDialectIncGen",
":BuiltinLocationAttributesIncGen",
":BuiltinOpsIncGen",
@@ -930,6 +946,12 @@ td_library(
)
td_library(
+ name = "BytecodeTdFiles",
+ srcs = ["include/mlir/IR/BytecodeBase.td"],
+ includes = ["include"],
+)
+
+td_library(
name = "CallInterfacesTdFiles",
srcs = ["include/mlir/Interfaces/CallInterfaces.td"],
includes = ["include"],