diff options
author | River Riddle <riddleriver@gmail.com> | 2020-12-01 14:30:18 -0800 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2020-12-01 15:05:50 -0800 |
commit | abfd1a8b3bc5ad8516a83c3ae7ba9f16032525ad (patch) | |
tree | bf3a05f6f0fc393c25f486729efaf612cf8c9ade | |
parent | c7dbaec396ef98b8bc6acb7631d2919449986add (diff) | |
download | llvm-abfd1a8b3bc5ad8516a83c3ae7ba9f16032525ad.tar.gz |
[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
23 files changed, 2841 insertions, 106 deletions
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index df49eb37b2a5..6b11c0dde809 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -108,7 +108,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { ```mlir // Apply `myConstraint` to the entities defined by `input`, `attr`, and // `op`. - pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; @@ -316,7 +316,7 @@ def PDLInterp_CheckTypeOp Example: ```mlir - pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest + pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest ``` }]; @@ -338,7 +338,7 @@ def PDLInterp_CreateAttributeOp Example: ```mlir - pdl_interp.create_attribute 10 : i64 + %attr = pdl_interp.create_attribute 10 : i64 ``` }]; @@ -369,7 +369,7 @@ def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> { Example: ```mlir - %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute + %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute ``` }]; @@ -772,7 +772,7 @@ def PDLInterp_SwitchAttributeOp Example: ```mlir - pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest + pdl_interp.switch_attribute %attr to [10, true](^10Dest, ^trueDest) -> ^defaultDest ``` }]; let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues); @@ -837,7 +837,7 @@ def PDLInterp_SwitchOperationNameOp Example: ```mlir - pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest + pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest ``` }]; @@ -874,7 +874,7 @@ def PDLInterp_SwitchResultCountOp Example: ```mlir - pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest + pdl_interp.switch_result_count of %op to [0, 2](^0Dest, ^2Dest) -> ^defaultDest ``` }]; diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index fc16effbba70..6cf2df9a1406 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -58,6 +58,7 @@ class SuccessorRange final SuccessorRange, BlockOperand *, Block *, Block *, Block *> { public: using RangeBaseT::RangeBaseT; + SuccessorRange(); SuccessorRange(Block *block); SuccessorRange(Operation *term); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 5b3c44868db2..3d5bc66ee9e2 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -69,6 +69,9 @@ public: /// Remove this operation from its parent block and delete it. void erase(); + /// Remove the operation from its parent block, but don't delete it. + void remove(); + /// Create a deep copy of this operation, remapping any operands that use /// values outside of the operation using the map that is provided (leaving /// them alone if no entry is present). Replaces references to cloned diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 96d6d1194b60..74899c9565fe 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -349,7 +349,7 @@ public: void *getAsOpaquePointer() const { return static_cast<void *>(representation.getOpaqueValue()); } - static OperationName getFromOpaquePointer(void *pointer); + static OperationName getFromOpaquePointer(const void *pointer); private: RepresentationUnion representation; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 2158f09cc469..4fdc0878c590 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -10,6 +10,7 @@ #define MLIR_PATTERNMATCHER_H #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" namespace mlir { @@ -226,6 +227,189 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern { }; //===----------------------------------------------------------------------===// +// PDLPatternModule +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// PDLValue + +/// Storage type of byte-code interpreter values. These are passed to constraint +/// functions as arguments. +class PDLValue { + /// The internal implementation type when the value is an Attribute, + /// Operation*, or Type. See `impl` below for more details. + using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>; + +public: + PDLValue(const PDLValue &other) : impl(other.impl) {} + PDLValue(std::nullptr_t = nullptr) : impl() {} + PDLValue(Attribute value) : impl(value) {} + PDLValue(Operation *value) : impl(value) {} + PDLValue(Type value) : impl(value) {} + PDLValue(Value value) : impl(value) {} + + /// Returns true if the type of the held value is `T`. + template <typename T> + std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const { + return impl.is<Value>(); + } + template <typename T> + std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const { + auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>(); + return attrOpTypeImpl && attrOpTypeImpl.is<T>(); + } + + /// Attempt to dynamically cast this value to type `T`, returns null if this + /// value is not an instance of `T`. + template <typename T> + std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const { + return impl.dyn_cast<T>(); + } + template <typename T> + std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const { + auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>(); + return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>(); + } + + /// Cast this value to type `T`, asserts if this value is not an instance of + /// `T`. + template <typename T> + std::enable_if_t<std::is_same<T, Value>::value, T> cast() const { + return impl.get<T>(); + } + template <typename T> + std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const { + return impl.get<AttrOpTypeImplT>().get<T>(); + } + + /// Get an opaque pointer to the value. + void *getAsOpaquePointer() { return impl.getOpaqueValue(); } + + /// Print this value to the provided output stream. + void print(raw_ostream &os); + +private: + /// The internal opaque representation of a PDLValue. We use a nested + /// PointerUnion structure here because `Value` only has 1 low bit + /// available, where as the remaining types all have 3. + llvm::PointerUnion<AttrOpTypeImplT, Value> impl; +}; + +inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { + value.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// PDLPatternModule + +/// A generic PDL pattern constraint function. This function applies a +/// constraint to a given set of opaque PDLValue entities. The second parameter +/// is a set of constant value parameters specified in Attribute form. Returns +/// success if the constraint successfully held, failure otherwise. +using PDLConstraintFunction = std::function<LogicalResult( + ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>; +/// A native PDL creation function. This function creates a new PDLValue given +/// a set of existing PDL values, a set of constant parameters specified in +/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue. +using PDLCreateFunction = + std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>; +/// A native PDL rewrite function. This function rewrites the given root +/// operation using the provided PatternRewriter. This method is only invoked +/// when the corresponding match was successful. +using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>, + ArrayAttr, PatternRewriter &)>; +/// A generic PDL pattern constraint function. This function applies a +/// constraint to a given opaque PDLValue entity. The second parameter is a set +/// of constant value parameters specified in Attribute form. Returns success if +/// the constraint successfully held, failure otherwise. +using PDLSingleEntityConstraintFunction = + std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>; + +/// This class contains all of the necessary data for a set of PDL patterns, or +/// pattern rewrites specified in the form of the PDL dialect. This PDL module +/// contained by this pattern may contain any number of `pdl.pattern` +/// operations. +class PDLPatternModule { +public: + PDLPatternModule() = default; + + /// Construct a PDL pattern with the given module. + PDLPatternModule(OwningModuleRef pdlModule) + : pdlModule(std::move(pdlModule)) {} + + /// Merge the state in `other` into this pattern module. + void mergeIn(PDLPatternModule &&other); + + /// Return the internal PDL module of this pattern. + ModuleOp getModule() { return pdlModule.get(); } + + //===--------------------------------------------------------------------===// + // Function Registry + + /// Register a constraint function. + void registerConstraintFunction(StringRef name, + PDLConstraintFunction constraintFn); + /// Register a single entity constraint function. + template <typename SingleEntityFn> + std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>, + ArrayAttr, PatternRewriter &>::value> + registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { + registerConstraintFunction(name, [=](ArrayRef<PDLValue> values, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + assert(values.size() == 1 && "expected values to have a single entity"); + return constraintFn(values[0], constantParams, rewriter); + }); + } + + /// Register a creation function. + void registerCreateFunction(StringRef name, PDLCreateFunction createFn); + + /// Register a rewrite function. + void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); + + /// Return the set of the registered constraint functions. + const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { + return constraintFunctions; + } + llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() { + return constraintFunctions; + } + /// Return the set of the registered create functions. + const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const { + return createFunctions; + } + llvm::StringMap<PDLCreateFunction> takeCreateFunctions() { + return createFunctions; + } + /// Return the set of the registered rewrite functions. + const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const { + return rewriteFunctions; + } + llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() { + return rewriteFunctions; + } + + /// Clear out the patterns and functions within this module. + void clear() { + pdlModule = nullptr; + constraintFunctions.clear(); + createFunctions.clear(); + rewriteFunctions.clear(); + } + +private: + /// The module containing the `pdl.pattern` operations. + OwningModuleRef pdlModule; + + /// The external functions referenced from within the PDL module. + llvm::StringMap<PDLConstraintFunction> constraintFunctions; + llvm::StringMap<PDLCreateFunction> createFunctions; + llvm::StringMap<PDLRewriteFunction> rewriteFunctions; +}; + +//===----------------------------------------------------------------------===// // PatternRewriter //===----------------------------------------------------------------------===// @@ -384,28 +568,28 @@ private: //===----------------------------------------------------------------------===// class OwningRewritePatternList { - using PatternListT = std::vector<std::unique_ptr<RewritePattern>>; + using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; public: OwningRewritePatternList() = default; - /// Construct a OwningRewritePatternList populated with the pattern `t` of - /// type `T`. - template <typename T> - OwningRewritePatternList(T &&t) { - patterns.emplace_back(std::make_unique<T>(std::forward<T>(t))); + /// Construct a OwningRewritePatternList populated with the given pattern. + OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) { + nativePatterns.emplace_back(std::move(pattern)); } + OwningRewritePatternList(PDLPatternModule &&pattern) + : pdlPatterns(std::move(pattern)) {} + + /// Return the native patterns held in this list. + NativePatternListT &getNativePatterns() { return nativePatterns; } - PatternListT::iterator begin() { return patterns.begin(); } - PatternListT::iterator end() { return patterns.end(); } - PatternListT::const_iterator begin() const { return patterns.begin(); } - PatternListT::const_iterator end() const { return patterns.end(); } - PatternListT::size_type size() const { return patterns.size(); } - void clear() { patterns.clear(); } + /// Return the PDL patterns held in this list. + PDLPatternModule &getPDLPatterns() { return pdlPatterns; } - /// Take ownership of the patterns held by this list. - std::vector<std::unique_ptr<RewritePattern>> takePatterns() { - return std::move(patterns); + /// Clear out all of the held patterns in this list. + void clear() { + nativePatterns.clear(); + pdlPatterns.clear(); } //===--------------------------------------------------------------------===// @@ -419,31 +603,53 @@ public: typename... ConstructorArgs, typename = std::enable_if_t<sizeof...(Ts) != 0>> OwningRewritePatternList &insert(ConstructorArg &&arg, - ConstructorArgs &&... args) { + ConstructorArgs &&...args) { // The following expands a call to emplace_back for each of the pattern // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list<int>{ - 0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...}; + (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...}; return *this; } /// Add an instance of each of the pattern types 'Ts'. Return a reference to /// `this` for chaining insertions. template <typename... Ts> OwningRewritePatternList &insert() { - (void)std::initializer_list<int>{ - 0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...}; + (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...}; return *this; } - /// Add the given pattern to the pattern list. - void insert(std::unique_ptr<RewritePattern> pattern) { - patterns.emplace_back(std::move(pattern)); + /// Add the given native pattern to the pattern list. Return a reference to + /// `this` for chaining insertions. + OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) { + nativePatterns.emplace_back(std::move(pattern)); + return *this; + } + + /// Add the given PDL pattern to the pattern list. Return a reference to + /// `this` for chaining insertions. + OwningRewritePatternList &insert(PDLPatternModule &&pattern) { + pdlPatterns.mergeIn(std::move(pattern)); + return *this; } private: - PatternListT patterns; + /// Add an instance of the pattern type 'T'. Return a reference to `this` for + /// chaining insertions. + template <typename T, typename... Args> + std::enable_if_t<std::is_base_of<RewritePattern, T>::value> + insertImpl(Args &&...args) { + nativePatterns.emplace_back( + std::make_unique<T>(std::forward<Args>(args)...)); + } + template <typename T, typename... Args> + std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> + insertImpl(Args &&...args) { + pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); + } + + NativePatternListT nativePatterns; + PDLPatternModule pdlPatterns; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index c0096bb6b233..719bb1a62f97 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -104,6 +104,12 @@ public: return UniquerT::template get<ConcreteT>(loc.getContext(), args...); } + /// Get an instance of the concrete type from a void pointer. + static ConcreteT getFromOpaquePointer(const void *ptr) { + return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>() + : nullptr; + } + protected: /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h index fb2657d99232..c2335b9dd5a1 100644 --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -12,25 +12,40 @@ #include "mlir/IR/PatternMatch.h" namespace mlir { +namespace detail { +class PDLByteCode; +} // end namespace detail + /// This class represents a frozen set of patterns that can be processed by a /// pattern applicator. This class is designed to enable caching pattern lists /// such that they need not be continuously recomputed. class FrozenRewritePatternList { - using PatternListT = std::vector<std::unique_ptr<RewritePattern>>; + using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; public: /// Freeze the patterns held in `patterns`, and take ownership. FrozenRewritePatternList(OwningRewritePatternList &&patterns); + FrozenRewritePatternList(FrozenRewritePatternList &&patterns); + ~FrozenRewritePatternList(); + + /// Return the native patterns held by this list. + iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>> + getNativePatterns() const { + return llvm::make_pointee_range(nativePatterns); + } - /// Return the patterns held by this list. - iterator_range<llvm::pointee_iterator<PatternListT::const_iterator>> - getPatterns() const { - return llvm::make_pointee_range(patterns); + /// Return the compiled PDL bytecode held by this list. Returns null if + /// there are no PDL patterns within the list. + const detail::PDLByteCode *getPDLByteCode() const { + return pdlByteCode.get(); } private: - /// The patterns held by this list. - std::vector<std::unique_ptr<RewritePattern>> patterns; + /// The set of. + std::vector<std::unique_ptr<RewritePattern>> nativePatterns; + + /// The bytecode containing the compiled PDL patterns. + std::unique_ptr<detail::PDLByteCode> pdlByteCode; }; } // end namespace mlir diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h index cb7794bab9fc..9d197175b47d 100644 --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -19,6 +19,10 @@ namespace mlir { class PatternRewriter; +namespace detail { +class PDLByteCodeMutableState; +} // end namespace detail + /// This class manages the application of a group of rewrite patterns, with a /// user-provided cost model. class PatternApplicator { @@ -29,8 +33,8 @@ public: /// `impossibleToMatch`. using CostModel = function_ref<PatternBenefit(const Pattern &)>; - explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList) - : frozenPatternList(frozenPatternList) {} + explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList); + ~PatternApplicator(); /// Attempt to match and rewrite the given op with any pattern, allowing a /// predicate to decide if a pattern can be applied or not, and hooks for if @@ -60,16 +64,6 @@ public: void walkAllPatterns(function_ref<void(const Pattern &)> walk); private: - /// Attempt to match and rewrite the given op with the given pattern, allowing - /// a predicate to decide if a pattern can be applied or not, and hooks for if - /// the pattern match was a success or failure. - LogicalResult - matchAndRewrite(Operation *op, const RewritePattern &pattern, - PatternRewriter &rewriter, - function_ref<bool(const Pattern &)> canApply, - function_ref<void(const Pattern &)> onFailure, - function_ref<LogicalResult(const Pattern &)> onSuccess); - /// The list that owns the patterns used within this applicator. const FrozenRewritePatternList &frozenPatternList; /// The set of patterns to match for each operation, stable sorted by benefit. @@ -77,6 +71,8 @@ private: /// The set of patterns that may match against any operation type, stable /// sorted by benefit. SmallVector<const RewritePattern *, 1> anyOpPatterns; + /// The mutable state used during execution of the PDL bytecode. + std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState; }; } // end namespace mlir diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index b9ddabb80800..79e7daa12a7c 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -302,13 +302,15 @@ unsigned PredecessorIterator::getSuccessorIndex() const { // SuccessorRange //===----------------------------------------------------------------------===// -SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {} + +SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() { if (Operation *term = block->getTerminator()) if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } -SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index e725dd87d93f..3037bf082d58 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -61,8 +61,9 @@ const AbstractOperation *OperationName::getAbstractOperation() const { return representation.dyn_cast<const AbstractOperation *>(); } -OperationName OperationName::getFromOpaquePointer(void *pointer) { - return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); +OperationName OperationName::getFromOpaquePointer(const void *pointer) { + return OperationName( + RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer))); } //===----------------------------------------------------------------------===// @@ -484,6 +485,12 @@ void Operation::erase() { destroy(); } +/// Remove the operation from its parent block, but don't delete it. +void Operation::remove() { + if (Block *parent = getBlock()) + parent->getOperations().remove(this); +} + /// Unlink this operation from its current block and insert it right before /// `existingOp` which may be in the same or another block in the same /// function. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index edd5e7b9d6d7..6558fcf4606d 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -70,6 +70,84 @@ LogicalResult RewritePattern::match(Operation *op) const { void RewritePattern::anchor() {} //===----------------------------------------------------------------------===// +// PDLValue +//===----------------------------------------------------------------------===// + +void PDLValue::print(raw_ostream &os) { + if (!impl) { + os << "<Null-PDLValue>"; + return; + } + if (Value val = impl.dyn_cast<Value>()) { + os << val; + return; + } + AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>(); + if (Attribute attr = aotImpl.dyn_cast<Attribute>()) + os << attr; + else if (Operation *op = aotImpl.dyn_cast<Operation *>()) + os << *op; + else + os << aotImpl.get<Type>(); +} + +//===----------------------------------------------------------------------===// +// PDLPatternModule +//===----------------------------------------------------------------------===// + +void PDLPatternModule::mergeIn(PDLPatternModule &&other) { + // Ignore the other module if it has no patterns. + if (!other.pdlModule) + return; + // Steal the other state if we have no patterns. + if (!pdlModule) { + constraintFunctions = std::move(other.constraintFunctions); + createFunctions = std::move(other.createFunctions); + rewriteFunctions = std::move(other.rewriteFunctions); + pdlModule = std::move(other.pdlModule); + return; + } + // Steal the functions of the other module. + for (auto &it : constraintFunctions) + registerConstraintFunction(it.first(), std::move(it.second)); + for (auto &it : createFunctions) + registerCreateFunction(it.first(), std::move(it.second)); + for (auto &it : rewriteFunctions) + registerRewriteFunction(it.first(), std::move(it.second)); + + // Merge the pattern operations from the other module into this one. + Block *block = pdlModule->getBody(); + block->getTerminator()->erase(); + block->getOperations().splice(block->end(), + other.pdlModule->getBody()->getOperations()); +} + +//===----------------------------------------------------------------------===// +// Function Registry + +void PDLPatternModule::registerConstraintFunction( + StringRef name, PDLConstraintFunction constraintFn) { + auto it = constraintFunctions.try_emplace(name, std::move(constraintFn)); + (void)it; + assert(it.second && + "constraint with the given name has already been registered"); +} +void PDLPatternModule::registerCreateFunction(StringRef name, + PDLCreateFunction createFn) { + auto it = createFunctions.try_emplace(name, std::move(createFn)); + (void)it; + assert(it.second && "native create function with the given name has " + "already been registered"); +} +void PDLPatternModule::registerRewriteFunction(StringRef name, + PDLRewriteFunction rewriteFn) { + auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); + (void)it; + assert(it.second && "native rewrite function with the given name has " + "already been registered"); +} + +//===----------------------------------------------------------------------===// // PatternRewriter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp new file mode 100644 index 000000000000..ae5f322d2948 --- /dev/null +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -0,0 +1,1262 @@ +//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements MLIR to byte-code generation and the interpreter. +// +//===----------------------------------------------------------------------===// + +#include "ByteCode.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/RegionGraphTraits.h" +#include "llvm/ADT/IntervalMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "pdl-bytecode" + +using namespace mlir; +using namespace mlir::detail; + +//===----------------------------------------------------------------------===// +// PDLByteCodePattern +//===----------------------------------------------------------------------===// + +PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, + ByteCodeAddr rewriterAddr) { + SmallVector<StringRef, 8> generatedOps; + if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) + generatedOps = + llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); + + PatternBenefit benefit = matchOp.benefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Check to see if this is pattern matches a specific operation type. + if (Optional<StringRef> rootKind = matchOp.rootKind()) + return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, + ctx); + return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, + MatchAnyOpTypeTag()); +} + +//===----------------------------------------------------------------------===// +// PDLByteCodeMutableState +//===----------------------------------------------------------------------===// + +/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds +/// to the position of the pattern within the range returned by +/// `PDLByteCode::getPatterns`. +void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, + PatternBenefit benefit) { + currentPatternBenefits[patternIndex] = benefit; +} + +//===----------------------------------------------------------------------===// +// Bytecode OpCodes +//===----------------------------------------------------------------------===// + +namespace { +enum OpCode : ByteCodeField { + /// Apply an externally registered constraint. + ApplyConstraint, + /// Apply an externally registered rewrite. + ApplyRewrite, + /// Check if two generic values are equal. + AreEqual, + /// Unconditional branch. + Branch, + /// Compare the operand count of an operation with a constant. + CheckOperandCount, + /// Compare the name of an operation with a constant. + CheckOperationName, + /// Compare the result count of an operation with a constant. + CheckResultCount, + /// Invoke a native creation method. + CreateNative, + /// Create an operation. + CreateOperation, + /// Erase an operation. + EraseOp, + /// Terminate a matcher or rewrite sequence. + Finalize, + /// Get a specific attribute of an operation. + GetAttribute, + /// Get the type of an attribute. + GetAttributeType, + /// Get the defining operation of a value. + GetDefiningOp, + /// Get a specific operand of an operation. + GetOperand0, + GetOperand1, + GetOperand2, + GetOperand3, + GetOperandN, + /// Get a specific result of an operation. + GetResult0, + GetResult1, + GetResult2, + GetResult3, + GetResultN, + /// Get the type of a value. + GetValueType, + /// Check if a generic value is not null. + IsNotNull, + /// Record a successful pattern match. + RecordMatch, + /// Replace an operation. + ReplaceOp, + /// Compare an attribute with a set of constants. + SwitchAttribute, + /// Compare the operand count of an operation with a set of constants. + SwitchOperandCount, + /// Compare the name of an operation with a set of constants. + SwitchOperationName, + /// Compare the result count of an operation with a set of constants. + SwitchResultCount, + /// Compare a type with a set of constants. + SwitchType, +}; + +enum class PDLValueKind { Attribute, Operation, Type, Value }; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ByteCode Generation +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Generator + +namespace { +struct ByteCodeWriter; + +/// This class represents the main generator for the pattern bytecode. +class Generator { +public: + Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, + SmallVectorImpl<ByteCodeField> &matcherByteCode, + SmallVectorImpl<ByteCodeField> &rewriterByteCode, + SmallVectorImpl<PDLByteCodePattern> &patterns, + ByteCodeField &maxValueMemoryIndex, + llvm::StringMap<PDLConstraintFunction> &constraintFns, + llvm::StringMap<PDLCreateFunction> &createFns, + llvm::StringMap<PDLRewriteFunction> &rewriteFns) + : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), + rewriterByteCode(rewriterByteCode), patterns(patterns), + maxValueMemoryIndex(maxValueMemoryIndex) { + for (auto it : llvm::enumerate(constraintFns)) + constraintToMemIndex.try_emplace(it.value().first(), it.index()); + for (auto it : llvm::enumerate(createFns)) + nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); + for (auto it : llvm::enumerate(rewriteFns)) + externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); + } + + /// Generate the bytecode for the given PDL interpreter module. + void generate(ModuleOp module); + + /// Return the memory index to use for the given value. + ByteCodeField &getMemIndex(Value value) { + assert(valueToMemIndex.count(value) && + "expected memory index to be assigned"); + return valueToMemIndex[value]; + } + + /// Return an index to use when referring to the given data that is uniqued in + /// the MLIR context. + template <typename T> + std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> + getMemIndex(T val) { + const void *opaqueVal = val.getAsOpaquePointer(); + + // Get or insert a reference to this value. + auto it = uniquedDataToMemIndex.try_emplace( + opaqueVal, maxValueMemoryIndex + uniquedData.size()); + if (it.second) + uniquedData.push_back(opaqueVal); + return it.first->second; + } + +private: + /// Allocate memory indices for the results of operations within the matcher + /// and rewriters. + void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); + + /// Generate the bytecode for the given operation. + void generate(Operation *op, ByteCodeWriter &writer); + void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); + void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); + void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); + void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); + void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); + + /// Mapping from value to its corresponding memory index. + DenseMap<Value, ByteCodeField> valueToMemIndex; + + /// Mapping from the name of an externally registered rewrite to its index in + /// the bytecode registry. + llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; + + /// Mapping from the name of an externally registered constraint to its index + /// in the bytecode registry. + llvm::StringMap<ByteCodeField> constraintToMemIndex; + + /// Mapping from the name of an externally registered creation method to its + /// index in the bytecode registry. + llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; + + /// Mapping from rewriter function name to the bytecode address of the + /// rewriter function in byte. + llvm::StringMap<ByteCodeAddr> rewriterToAddr; + + /// Mapping from a uniqued storage object to its memory index within + /// `uniquedData`. + DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; + + /// The current MLIR context. + MLIRContext *ctx; + + /// Data of the ByteCode class to be populated. + std::vector<const void *> &uniquedData; + SmallVectorImpl<ByteCodeField> &matcherByteCode; + SmallVectorImpl<ByteCodeField> &rewriterByteCode; + SmallVectorImpl<PDLByteCodePattern> &patterns; + ByteCodeField &maxValueMemoryIndex; +}; + +/// This class provides utilities for writing a bytecode stream. +struct ByteCodeWriter { + ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) + : bytecode(bytecode), generator(generator) {} + + /// Append a field to the bytecode. + void append(ByteCodeField field) { bytecode.push_back(field); } + + /// Append an address to the bytecode. + void append(ByteCodeAddr field) { + static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, + "unexpected ByteCode address size"); + + ByteCodeField fieldParts[2]; + std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); + bytecode.append({fieldParts[0], fieldParts[1]}); + } + + /// Append a successor range to the bytecode, the exact address will need to + /// be resolved later. + void append(SuccessorRange successors) { + // Add back references to the any successors so that the address can be + // resolved later. + for (Block *successor : successors) { + unresolvedSuccessorRefs[successor].push_back(bytecode.size()); + append(ByteCodeAddr(0)); + } + } + + /// Append a range of values that will be read as generic PDLValues. + void appendPDLValueList(OperandRange values) { + bytecode.push_back(values.size()); + for (Value value : values) { + // Append the type of the value in addition to the value itself. + PDLValueKind kind = + TypeSwitch<Type, PDLValueKind>(value.getType()) + .Case<pdl::AttributeType>( + [](Type) { return PDLValueKind::Attribute; }) + .Case<pdl::OperationType>( + [](Type) { return PDLValueKind::Operation; }) + .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) + .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); + bytecode.push_back(static_cast<ByteCodeField>(kind)); + append(value); + } + } + + /// Check if the given class `T` has an iterator type. + template <typename T, typename... Args> + using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); + + /// Append a value that will be stored in a memory slot and not inline within + /// the bytecode. + template <typename T> + std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || + std::is_pointer<T>::value> + append(T value) { + bytecode.push_back(generator.getMemIndex(value)); + } + + /// Append a range of values. + template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> + std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> + append(T range) { + bytecode.push_back(llvm::size(range)); + for (auto it : range) + append(it); + } + + /// Append a variadic number of fields to the bytecode. + template <typename FieldTy, typename Field2Ty, typename... FieldTys> + void append(FieldTy field, Field2Ty field2, FieldTys... fields) { + append(field); + append(field2, fields...); + } + + /// Successor references in the bytecode that have yet to be resolved. + DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; + + /// The underlying bytecode buffer. + SmallVectorImpl<ByteCodeField> &bytecode; + + /// The main generator producing PDL. + Generator &generator; +}; +} // end anonymous namespace + +void Generator::generate(ModuleOp module) { + FuncOp matcherFunc = module.lookupSymbol<FuncOp>( + pdl_interp::PDLInterpDialect::getMatcherFunctionName()); + ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( + pdl_interp::PDLInterpDialect::getRewriterModuleName()); + assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); + + // Allocate memory indices for the results of operations within the matcher + // and rewriters. + allocateMemoryIndices(matcherFunc, rewriterModule); + + // Generate code for the rewriter functions. + ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); + for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { + rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); + for (Operation &op : rewriterFunc.getOps()) + generate(&op, rewriterByteCodeWriter); + } + assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && + "unexpected branches in rewriter function"); + + // Generate code for the matcher function. + DenseMap<Block *, ByteCodeAddr> blockToAddr; + llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); + ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); + for (Block *block : rpot) { + // Keep track of where this block begins within the matcher function. + blockToAddr.try_emplace(block, matcherByteCode.size()); + for (Operation &op : *block) + generate(&op, matcherByteCodeWriter); + } + + // Resolve successor references in the matcher. + for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { + ByteCodeAddr addr = blockToAddr[it.first]; + for (unsigned offsetToFix : it.second) + std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); + } +} + +void Generator::allocateMemoryIndices(FuncOp matcherFunc, + ModuleOp rewriterModule) { + // Rewriters use simplistic allocation scheme that simply assigns an index to + // each result. + for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { + ByteCodeField index = 0; + for (BlockArgument arg : rewriterFunc.getArguments()) + valueToMemIndex.try_emplace(arg, index++); + rewriterFunc.getBody().walk([&](Operation *op) { + for (Value result : op->getResults()) + valueToMemIndex.try_emplace(result, index++); + }); + if (index > maxValueMemoryIndex) + maxValueMemoryIndex = index; + } + + // The matcher function uses a more sophisticated numbering that tries to + // minimize the number of memory indices assigned. This is done by determining + // a live range of the values within the matcher, then the allocation is just + // finding the minimal number of overlapping live ranges. This is essentially + // a simplified form of register allocation where we don't necessarily have a + // limited number of registers, but we still want to minimize the number used. + DenseMap<Operation *, ByteCodeField> opToIndex; + matcherFunc.getBody().walk([&](Operation *op) { + opToIndex.insert(std::make_pair(op, opToIndex.size())); + }); + + // Liveness info for each of the defs within the matcher. + using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; + LivenessSet::Allocator allocator; + DenseMap<Value, LivenessSet> valueDefRanges; + + // Assign the root operation being matched to slot 0. + BlockArgument rootOpArg = matcherFunc.getArgument(0); + valueToMemIndex[rootOpArg] = 0; + + // Walk each of the blocks, computing the def interval that the value is used. + Liveness matcherLiveness(matcherFunc); + for (Block &block : matcherFunc.getBody()) { + const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); + assert(info && "expected liveness info for block"); + auto processValue = [&](Value value, Operation *firstUseOrDef) { + // We don't need to process the root op argument, this value is always + // assigned to the first memory slot. + if (value == rootOpArg) + return; + + // Set indices for the range of this block that the value is used. + auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; + defRangeIt->second.insert( + opToIndex[firstUseOrDef], + opToIndex[info->getEndOperation(value, firstUseOrDef)], + /*dummyValue*/ 0); + }; + + // Process the live-ins of this block. + for (Value liveIn : info->in()) + processValue(liveIn, &block.front()); + + // Process any new defs within this block. + for (Operation &op : block) + for (Value result : op.getResults()) + processValue(result, &op); + } + + // Greedily allocate memory slots using the computed def live ranges. + std::vector<LivenessSet> allocatedIndices; + for (auto &defIt : valueDefRanges) { + ByteCodeField &memIndex = valueToMemIndex[defIt.first]; + LivenessSet &defSet = defIt.second; + + // Try to allocate to an existing index. + for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { + LivenessSet &existingIndex = existingIndexIt.value(); + llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( + defIt.second, existingIndex); + if (overlaps.valid()) + continue; + // Union the range of the def within the existing index. + for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) + existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); + memIndex = existingIndexIt.index() + 1; + } + + // If no existing index could be used, add a new one. + if (memIndex == 0) { + allocatedIndices.emplace_back(allocator); + for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) + allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); + memIndex = allocatedIndices.size(); + } + } + + // Update the max number of indices. + ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; + if (numMatcherIndices > maxValueMemoryIndex) + maxValueMemoryIndex = numMatcherIndices; +} + +void Generator::generate(Operation *op, ByteCodeWriter &writer) { + TypeSwitch<Operation *>(op) + .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, + pdl_interp::AreEqualOp, pdl_interp::BranchOp, + pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, + pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, + pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, + pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, + pdl_interp::CreateTypeOp, pdl_interp::EraseOp, + pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, + pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, + pdl_interp::GetOperandOp, pdl_interp::GetResultOp, + pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, + pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, + pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, + pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, + pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( + [&](auto interpOp) { this->generate(interpOp, writer); }) + .Default([](Operation *) { + llvm_unreachable("unknown `pdl_interp` operation"); + }); +} + +void Generator::generate(pdl_interp::ApplyConstraintOp op, + ByteCodeWriter &writer) { + assert(constraintToMemIndex.count(op.name()) && + "expected index for constraint function"); + writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], + op.constParamsAttr()); + writer.appendPDLValueList(op.args()); + writer.append(op.getSuccessors()); +} +void Generator::generate(pdl_interp::ApplyRewriteOp op, + ByteCodeWriter &writer) { + assert(externalRewriterToMemIndex.count(op.name()) && + "expected index for rewrite function"); + writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], + op.constParamsAttr(), op.root()); + writer.appendPDLValueList(op.args()); +} +void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { + writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); +} +void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { + writer.append(OpCode::Branch, SuccessorRange(op)); +} +void Generator::generate(pdl_interp::CheckAttributeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::CheckOperandCountOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::CheckOperationNameOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::CheckOperationName, op.operation(), + OperationName(op.name(), ctx), op.getSuccessors()); +} +void Generator::generate(pdl_interp::CheckResultCountOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::CheckResultCount, op.operation(), op.count(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { + writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); +} +void Generator::generate(pdl_interp::CreateAttributeOp op, + ByteCodeWriter &writer) { + // Simply repoint the memory index of the result to the constant. + getMemIndex(op.attribute()) = getMemIndex(op.value()); +} +void Generator::generate(pdl_interp::CreateNativeOp op, + ByteCodeWriter &writer) { + assert(nativeCreateToMemIndex.count(op.name()) && + "expected index for creation function"); + writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], + op.result(), op.constParamsAttr()); + writer.appendPDLValueList(op.args()); +} +void Generator::generate(pdl_interp::CreateOperationOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::CreateOperation, op.operation(), + OperationName(op.name(), ctx), op.operands()); + + // Add the attributes. + OperandRange attributes = op.attributes(); + writer.append(static_cast<ByteCodeField>(attributes.size())); + for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { + writer.append( + Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), + std::get<1>(it)); + } + writer.append(op.types()); +} +void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { + // Simply repoint the memory index of the result to the constant. + getMemIndex(op.result()) = getMemIndex(op.value()); +} +void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { + writer.append(OpCode::EraseOp, op.operation()); +} +void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { + writer.append(OpCode::Finalize); +} +void Generator::generate(pdl_interp::GetAttributeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), + Identifier::get(op.name(), ctx)); +} +void Generator::generate(pdl_interp::GetAttributeTypeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::GetAttributeType, op.result(), op.value()); +} +void Generator::generate(pdl_interp::GetDefiningOpOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); +} +void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { + uint32_t index = op.index(); + if (index < 4) + writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); + else + writer.append(OpCode::GetOperandN, index); + writer.append(op.operation(), op.value()); +} +void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { + uint32_t index = op.index(); + if (index < 4) + writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); + else + writer.append(OpCode::GetResultN, index); + writer.append(op.operation(), op.value()); +} +void Generator::generate(pdl_interp::GetValueTypeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::GetValueType, op.result(), op.value()); +} +void Generator::generate(pdl_interp::InferredTypeOp op, + ByteCodeWriter &writer) { + // InferType maps to a null type as a marker for inferring a result type. + getMemIndex(op.type()) = getMemIndex(Type()); +} +void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { + writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); +} +void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { + ByteCodeField patternIndex = patterns.size(); + patterns.emplace_back(PDLByteCodePattern::create( + op, rewriterToAddr[op.rewriter().getLeafReference()])); + writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op), + op.matchedOps(), op.inputs()); +} +void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { + writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); +} +void Generator::generate(pdl_interp::SwitchAttributeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::SwitchOperandCountOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::SwitchOperationNameOp op, + ByteCodeWriter &writer) { + auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { + return OperationName(attr.cast<StringAttr>().getValue(), ctx); + }); + writer.append(OpCode::SwitchOperationName, op.operation(), cases, + op.getSuccessors()); +} +void Generator::generate(pdl_interp::SwitchResultCountOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), + op.getSuccessors()); +} +void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { + writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), + op.getSuccessors()); +} + +//===----------------------------------------------------------------------===// +// PDLByteCode +//===----------------------------------------------------------------------===// + +PDLByteCode::PDLByteCode(ModuleOp module, + llvm::StringMap<PDLConstraintFunction> constraintFns, + llvm::StringMap<PDLCreateFunction> createFns, + llvm::StringMap<PDLRewriteFunction> rewriteFns) { + Generator generator(module.getContext(), uniquedData, matcherByteCode, + rewriterByteCode, patterns, maxValueMemoryIndex, + constraintFns, createFns, rewriteFns); + generator.generate(module); + + // Initialize the external functions. + for (auto &it : constraintFns) + constraintFunctions.push_back(std::move(it.second)); + for (auto &it : createFns) + createFunctions.push_back(std::move(it.second)); + for (auto &it : rewriteFns) + rewriteFunctions.push_back(std::move(it.second)); +} + +/// Initialize the given state such that it can be used to execute the current +/// bytecode. +void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { + state.memory.resize(maxValueMemoryIndex, nullptr); + state.currentPatternBenefits.reserve(patterns.size()); + for (const PDLByteCodePattern &pattern : patterns) + state.currentPatternBenefits.push_back(pattern.getBenefit()); +} + +//===----------------------------------------------------------------------===// +// ByteCode Execution + +namespace { +/// This class provides support for executing a bytecode stream. +class ByteCodeExecutor { +public: + ByteCodeExecutor(const ByteCodeField *curCodeIt, + MutableArrayRef<const void *> memory, + ArrayRef<const void *> uniquedMemory, + ArrayRef<ByteCodeField> code, + ArrayRef<PatternBenefit> currentPatternBenefits, + ArrayRef<PDLByteCodePattern> patterns, + ArrayRef<PDLConstraintFunction> constraintFunctions, + ArrayRef<PDLCreateFunction> createFunctions, + ArrayRef<PDLRewriteFunction> rewriteFunctions) + : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), + code(code), currentPatternBenefits(currentPatternBenefits), + patterns(patterns), constraintFunctions(constraintFunctions), + createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} + + /// Start executing the code at the current bytecode index. `matches` is an + /// optional field provided when this function is executed in a matching + /// context. + void execute(PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, + Optional<Location> mainRewriteLoc = {}); + +private: + /// Read a value from the bytecode buffer, optionally skipping a certain + /// number of prefix values. These methods always update the buffer to point + /// to the next field after the read data. + template <typename T = ByteCodeField> + T read(size_t skipN = 0) { + curCodeIt += skipN; + return readImpl<T>(); + } + ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } + + /// Read a list of values from the bytecode buffer. + template <typename ValueT, typename T> + void readList(SmallVectorImpl<T> &list) { + list.clear(); + for (unsigned i = 0, e = read(); i != e; ++i) + list.push_back(read<ValueT>()); + } + + /// Jump to a specific successor based on a predicate value. + void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } + /// Jump to a specific successor based on a destination index. + void selectJump(size_t destIndex) { + curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; + } + + /// Handle a switch operation with the provided value and cases. + template <typename T, typename RangeT> + void handleSwitch(const T &value, RangeT &&cases) { + LLVM_DEBUG({ + llvm::dbgs() << " * Value: " << value << "\n" + << " * Cases: "; + llvm::interleaveComma(cases, llvm::dbgs()); + llvm::dbgs() << "\n\n"; + }); + + // Check to see if the attribute value is within the case list. Jump to + // the correct successor index based on the result. + auto it = llvm::find(cases, value); + selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1)); + } + + /// Internal implementation of reading various data types from the bytecode + /// stream. + template <typename T> + const void *readFromMemory() { + size_t index = *curCodeIt++; + + // If this type is an SSA value, it can only be stored in non-const memory. + if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) + return memory[index]; + + // Otherwise, if this index is not inbounds it is uniqued. + return uniquedMemory[index - memory.size()]; + } + template <typename T> + std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { + return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); + } + template <typename T> + std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, + T> + readImpl() { + return T(T::getFromOpaquePointer(readFromMemory<T>())); + } + template <typename T> + std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { + switch (static_cast<PDLValueKind>(read())) { + case PDLValueKind::Attribute: + return read<Attribute>(); + case PDLValueKind::Operation: + return read<Operation *>(); + case PDLValueKind::Type: + return read<Type>(); + case PDLValueKind::Value: + return read<Value>(); + } + } + template <typename T> + std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { + static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, + "unexpected ByteCode address size"); + ByteCodeAddr result; + std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); + curCodeIt += 2; + return result; + } + template <typename T> + std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { + return *curCodeIt++; + } + + /// The underlying bytecode buffer. + const ByteCodeField *curCodeIt; + + /// The current execution memory. + MutableArrayRef<const void *> memory; + + /// References to ByteCode data necessary for execution. + ArrayRef<const void *> uniquedMemory; + ArrayRef<ByteCodeField> code; + ArrayRef<PatternBenefit> currentPatternBenefits; + ArrayRef<PDLByteCodePattern> patterns; + ArrayRef<PDLConstraintFunction> constraintFunctions; + ArrayRef<PDLCreateFunction> createFunctions; + ArrayRef<PDLRewriteFunction> rewriteFunctions; +}; +} // end anonymous namespace + +void ByteCodeExecutor::execute( + PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches, + Optional<Location> mainRewriteLoc) { + while (true) { + OpCode opCode = static_cast<OpCode>(read()); + switch (opCode) { + case ApplyConstraint: { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); + const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ArrayAttr constParams = read<ArrayAttr>(); + SmallVector<PDLValue, 16> args; + readList<PDLValue>(args); + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; + }); + + // Invoke the constraint and jump to the proper destination. + selectJump(succeeded(constraintFn(args, constParams, rewriter))); + break; + } + case ApplyRewrite: { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); + const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; + ArrayAttr constParams = read<ArrayAttr>(); + Operation *root = read<Operation *>(); + SmallVector<PDLValue, 16> args; + readList<PDLValue>(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Root: " << *root << "\n" + << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; + }); + rewriteFn(root, args, constParams, rewriter); + break; + } + case AreEqual: { + LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + const void *lhs = read<const void *>(); + const void *rhs = read<const void *>(); + + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(lhs == rhs); + break; + } + case Branch: { + LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); + curCodeIt = &code[read<ByteCodeAddr>()]; + break; + } + case CheckOperandCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); + Operation *op = read<Operation *>(); + uint32_t expectedCount = read<uint32_t>(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" + << " * Expected: " << expectedCount << "\n\n"); + selectJump(op->getNumOperands() == expectedCount); + break; + } + case CheckOperationName: { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); + Operation *op = read<Operation *>(); + OperationName expectedName = read<OperationName>(); + + LLVM_DEBUG(llvm::dbgs() + << " * Found: \"" << op->getName() << "\"\n" + << " * Expected: \"" << expectedName << "\"\n\n"); + selectJump(op->getName() == expectedName); + break; + } + case CheckResultCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); + Operation *op = read<Operation *>(); + uint32_t expectedCount = read<uint32_t>(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" + << " * Expected: " << expectedCount << "\n\n"); + selectJump(op->getNumResults() == expectedCount); + break; + } + case CreateNative: { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); + const PDLCreateFunction &createFn = createFunctions[read()]; + ByteCodeField resultIndex = read(); + ArrayAttr constParams = read<ArrayAttr>(); + SmallVector<PDLValue, 16> args; + readList<PDLValue>(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; + }); + + PDLValue result = createFn(args, constParams, rewriter); + memory[resultIndex] = result.getAsOpaquePointer(); + + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); + break; + } + case CreateOperation: { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); + assert(mainRewriteLoc && "expected rewrite loc to be provided when " + "executing the rewriter bytecode"); + + unsigned memIndex = read(); + OperationState state(*mainRewriteLoc, read<OperationName>()); + readList<Value>(state.operands); + for (unsigned i = 0, e = read(); i != e; ++i) { + Identifier name = read<Identifier>(); + if (Attribute attr = read<Attribute>()) + state.addAttribute(name, attr); + } + + bool hasInferredTypes = false; + for (unsigned i = 0, e = read(); i != e; ++i) { + Type resultType = read<Type>(); + hasInferredTypes |= !resultType; + state.types.push_back(resultType); + } + + // Handle the case where the operation has inferred types. + if (hasInferredTypes) { + InferTypeOpInterface::Concept *concept = + state.name.getAbstractOperation() + ->getInterface<InferTypeOpInterface>(); + + // TODO: Handle failure. + SmallVector<Type, 2> inferredTypes; + if (failed(concept->inferReturnTypes( + state.getContext(), state.location, state.operands, + state.attributes.getDictionary(state.getContext()), + state.regions, inferredTypes))) + return; + + for (unsigned i = 0, e = state.types.size(); i != e; ++i) + if (!state.types[i]) + state.types[i] = inferredTypes[i]; + } + Operation *resultOp = rewriter.createOperation(state); + memory[memIndex] = resultOp; + + LLVM_DEBUG({ + llvm::dbgs() << " * Attributes: " + << state.attributes.getDictionary(state.getContext()) + << "\n * Operands: "; + llvm::interleaveComma(state.operands, llvm::dbgs()); + llvm::dbgs() << "\n * Result Types: "; + llvm::interleaveComma(state.types, llvm::dbgs()); + llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; + }); + break; + } + case EraseOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); + Operation *op = read<Operation *>(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); + rewriter.eraseOp(op); + break; + } + case Finalize: { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + return; + } + case GetAttribute: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); + unsigned memIndex = read(); + Operation *op = read<Operation *>(); + Identifier attrName = read<Identifier>(); + Attribute attr = op->getAttr(attrName); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Attribute: " << attrName << "\n" + << " * Result: " << attr << "\n\n"); + memory[memIndex] = attr.getAsOpaquePointer(); + break; + } + case GetAttributeType: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); + unsigned memIndex = read(); + Attribute attr = read<Attribute>(); + + LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" + << " * Result: " << attr.getType() << "\n\n"); + memory[memIndex] = attr.getType().getAsOpaquePointer(); + break; + } + case GetDefiningOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); + unsigned memIndex = read(); + Value value = read<Value>(); + Operation *op = value ? value.getDefiningOp() : nullptr; + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << *op << "\n\n"); + memory[memIndex] = op; + break; + } + case GetOperand0: + case GetOperand1: + case GetOperand2: + case GetOperand3: + case GetOperandN: { + LLVM_DEBUG({ + llvm::dbgs() << "Executing GetOperand" + << (opCode == GetOperandN ? Twine("N") + : Twine(opCode - GetOperand0)) + << ":\n"; + }); + unsigned index = + opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0); + Operation *op = read<Operation *>(); + unsigned memIndex = read(); + Value operand = + index < op->getNumOperands() ? op->getOperand(index) : Value(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << operand << "\n\n"); + memory[memIndex] = operand.getAsOpaquePointer(); + break; + } + case GetResult0: + case GetResult1: + case GetResult2: + case GetResult3: + case GetResultN: { + LLVM_DEBUG({ + llvm::dbgs() << "Executing GetResult" + << (opCode == GetResultN ? Twine("N") + : Twine(opCode - GetResult0)) + << ":\n"; + }); + unsigned index = + opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0); + Operation *op = read<Operation *>(); + unsigned memIndex = read(); + OpResult result = + index < op->getNumResults() ? op->getResult(index) : OpResult(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << result << "\n\n"); + memory[memIndex] = result.getAsOpaquePointer(); + break; + } + case GetValueType: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); + unsigned memIndex = read(); + Value value = read<Value>(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << value.getType() << "\n\n"); + memory[memIndex] = value.getType().getAsOpaquePointer(); + break; + } + case IsNotNull: { + LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); + const void *value = read<const void *>(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); + selectJump(value != nullptr); + break; + } + case RecordMatch: { + LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); + assert(matches && + "expected matches to be provided when executing the matcher"); + unsigned patternIndex = read(); + PatternBenefit benefit = currentPatternBenefits[patternIndex]; + const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; + + // If the benefit of the pattern is impossible, skip the processing of the + // rest of the pattern. + if (benefit.isImpossibleToMatch()) { + LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); + curCodeIt = dest; + break; + } + + // Create a fused location containing the locations of each of the + // operations used in the match. This will be used as the location for + // created operations during the rewrite that don't already have an + // explicit location set. + unsigned numMatchLocs = read(); + SmallVector<Location, 4> matchLocs; + matchLocs.reserve(numMatchLocs); + for (unsigned i = 0; i != numMatchLocs; ++i) + matchLocs.push_back(read<Operation *>()->getLoc()); + Location matchLoc = rewriter.getFusedLoc(matchLocs); + + LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" + << " * Location: " << matchLoc << "\n\n"); + matches->emplace_back(matchLoc, patterns[patternIndex], benefit); + readList<const void *>(matches->back().values); + curCodeIt = dest; + break; + } + case ReplaceOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); + Operation *op = read<Operation *>(); + SmallVector<Value, 16> args; + readList<Value>(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Values: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n\n"; + }); + rewriter.replaceOp(op, args); + break; + } + case SwitchAttribute: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); + Attribute value = read<Attribute>(); + ArrayAttr cases = read<ArrayAttr>(); + handleSwitch(value, cases); + break; + } + case SwitchOperandCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); + Operation *op = read<Operation *>(); + auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumOperands(), cases); + break; + } + case SwitchOperationName: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); + OperationName value = read<Operation *>()->getName(); + size_t caseCount = read(); + + // The operation names are stored in-line, so to print them out for + // debugging purposes we need to read the array before executing the + // switch so that we can display all of the possible values. + LLVM_DEBUG({ + const ByteCodeField *prevCodeIt = curCodeIt; + llvm::dbgs() << " * Value: " << value << "\n" + << " * Cases: "; + llvm::interleaveComma( + llvm::map_range(llvm::seq<size_t>(0, caseCount), + [&](size_t i) { return read<OperationName>(); }), + llvm::dbgs()); + llvm::dbgs() << "\n\n"; + curCodeIt = prevCodeIt; + }); + + // Try to find the switch value within any of the cases. + size_t jumpDest = 0; + for (size_t i = 0; i != caseCount; ++i) { + if (read<OperationName>() == value) { + curCodeIt += (caseCount - i - 1); + jumpDest = i + 1; + break; + } + } + selectJump(jumpDest); + break; + } + case SwitchResultCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); + Operation *op = read<Operation *>(); + auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumResults(), cases); + break; + } + case SwitchType: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); + Type value = read<Type>(); + auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); + handleSwitch(value, cases); + break; + } + } + } +} + +/// Run the pattern matcher on the given root operation, collecting the matched +/// patterns in `matches`. +void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl<MatchResult> &matches, + PDLByteCodeMutableState &state) const { + // The first memory slot is always the root operation. + state.memory[0] = op; + + // The matcher function always starts at code address 0. + ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, + matcherByteCode, state.currentPatternBenefits, + patterns, constraintFunctions, createFunctions, + rewriteFunctions); + executor.execute(rewriter, &matches); + + // Order the found matches by benefit. + std::stable_sort(matches.begin(), matches.end(), + [](const MatchResult &lhs, const MatchResult &rhs) { + return lhs.benefit > rhs.benefit; + }); +} + +/// Run the rewriter of the given pattern on the root operation `op`. +void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const { + // The arguments of the rewrite function are stored at the start of the + // memory buffer. + llvm::copy(match.values, state.memory.begin()); + + ByteCodeExecutor executor( + &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, + uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, + constraintFunctions, createFunctions, rewriteFunctions); + executor.execute(rewriter, /*matches=*/nullptr, match.location); +} diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h new file mode 100644 index 000000000000..7126037f864a --- /dev/null +++ b/mlir/lib/Rewrite/ByteCode.h @@ -0,0 +1,173 @@ +//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a byte-code and interpreter for pattern rewrites in MLIR. +// The byte-code is constructed from the PDL Interpreter dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_BYTECODE_H_ +#define MLIR_REWRITE_BYTECODE_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace pdl_interp { +class RecordMatchOp; +} // end namespace pdl_interp + +namespace detail { +class PDLByteCode; + +/// Use generic bytecode types. ByteCodeField refers to the actual bytecode +/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of +/// indices into the bytecode. Correctness is checked with static asserts. +using ByteCodeField = uint16_t; +using ByteCodeAddr = uint32_t; + +//===----------------------------------------------------------------------===// +// PDLByteCodePattern +//===----------------------------------------------------------------------===// + +/// All of the data pertaining to a specific pattern within the bytecode. +class PDLByteCodePattern : public Pattern { +public: + static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, + ByteCodeAddr rewriterAddr); + + /// Return the bytecode address of the rewriter for this pattern. + ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } + +private: + template <typename... Args> + PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) + : Pattern(std::forward<Args>(patternArgs)...), + rewriterAddr(rewriterAddr) {} + + /// The address of the rewriter for this pattern. + ByteCodeAddr rewriterAddr; +}; + +//===----------------------------------------------------------------------===// +// PDLByteCodeMutableState +//===----------------------------------------------------------------------===// + +/// This class contains the mutable state of a bytecode instance. This allows +/// for a bytecode instance to be cached and reused across various different +/// threads/drivers. +class PDLByteCodeMutableState { +public: + /// Initialize the state from a bytecode instance. + void initialize(PDLByteCode &bytecode); + + /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds + /// to the position of the pattern within the range returned by + /// `PDLByteCode::getPatterns`. + void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); + +private: + /// Allow access to data fields. + friend class PDLByteCode; + + /// The mutable block of memory used during the matching and rewriting phases + /// of the bytecode. + std::vector<const void *> memory; + + /// The up-to-date benefits of the patterns held by the bytecode. The order + /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. + std::vector<PatternBenefit> currentPatternBenefits; +}; + +//===----------------------------------------------------------------------===// +// PDLByteCode +//===----------------------------------------------------------------------===// + +/// The bytecode class is also the interpreter. Contains the bytecode itself, +/// the static info, addresses of the rewriter functions, the interpreter +/// memory buffer, and the execution context. +class PDLByteCode { +public: + /// Each successful match returns a MatchResult, which contains information + /// necessary to execute the rewriter and indicates the originating pattern. + struct MatchResult { + MatchResult(Location loc, const PDLByteCodePattern &pattern, + PatternBenefit benefit) + : location(loc), pattern(&pattern), benefit(benefit) {} + + /// The location of operations to be replaced. + Location location; + /// Memory values defined in the matcher that are passed to the rewriter. + SmallVector<const void *, 4> values; + /// The originating pattern that was matched. This is always non-null, but + /// represented with a pointer to allow for assignment. + const PDLByteCodePattern *pattern; + /// The current benefit of the pattern that was matched. + PatternBenefit benefit; + }; + + /// Create a ByteCode instance from the given module containing operations in + /// the PDL interpreter dialect. + PDLByteCode(ModuleOp module, + llvm::StringMap<PDLConstraintFunction> constraintFns, + llvm::StringMap<PDLCreateFunction> createFns, + llvm::StringMap<PDLRewriteFunction> rewriteFns); + + /// Return the patterns held by the bytecode. + ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } + + /// Initialize the given state such that it can be used to execute the current + /// bytecode. + void initializeMutableState(PDLByteCodeMutableState &state) const; + + /// Run the pattern matcher on the given root operation, collecting the + /// matched patterns in `matches`. + void match(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl<MatchResult> &matches, + PDLByteCodeMutableState &state) const; + + /// Run the rewriter of the given pattern that was previously matched in + /// `match`. + void rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const; + +private: + /// Execute the given byte code starting at the provided instruction `inst`. + /// `matches` is an optional field provided when this function is executed in + /// a matching context. + void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, + PDLByteCodeMutableState &state, + SmallVectorImpl<MatchResult> *matches) const; + + /// A vector containing pointers to unqiued data. The storage is intentionally + /// opaque such that we can store a wide range of data types. The types of + /// data stored here include: + /// * Attribute, Identifier, OperationName, Type + std::vector<const void *> uniquedData; + + /// A vector containing the generated bytecode for the matcher. + SmallVector<ByteCodeField, 64> matcherByteCode; + + /// A vector containing the generated bytecode for all of the rewriters. + SmallVector<ByteCodeField, 64> rewriterByteCode; + + /// The set of patterns contained within the bytecode. + SmallVector<PDLByteCodePattern, 32> patterns; + + /// A set of user defined functions invoked via PDL. + std::vector<PDLConstraintFunction> constraintFunctions; + std::vector<PDLCreateFunction> createFunctions; + std::vector<PDLRewriteFunction> rewriteFunctions; + + /// The maximum memory index used by a value. + ByteCodeField maxValueMemoryIndex = 0; +}; + +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_REWRITE_BYTECODE_H_ diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt index e37b9c31dab9..5822789cc916 100644 --- a/mlir/lib/Rewrite/CMakeLists.txt +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRRewrite + ByteCode.cpp FrozenRewritePatternList.cpp PatternApplicator.cpp @@ -10,4 +11,8 @@ add_mlir_library(MLIRRewrite LINK_LIBS PUBLIC MLIRIR + MLIRPDL + MLIRPDLInterp + MLIRPDLToPDLInterp + MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp index d0e45184ac28..60f6dcea88f2 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -7,13 +7,71 @@ //===----------------------------------------------------------------------===// #include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "ByteCode.h" +#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" using namespace mlir; +static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { + // Skip the conversion if the module doesn't contain pdl. + if (llvm::empty(pdlModule.getOps<pdl::PatternOp>())) + return success(); + + // Simplify the provided PDL module. Note that we can't use the canonicalizer + // here because it would create a cyclic dependency. + auto simplifyFn = [](Operation *op) { + // TODO: Add folding here if ever necessary. + if (isOpTriviallyDead(op)) + op->erase(); + }; + pdlModule.getBody()->walk(simplifyFn); + + /// Lower the PDL pattern module to the interpreter dialect. + PassManager pdlPipeline(pdlModule.getContext()); +#ifdef NDEBUG + // We don't want to incur the hit of running the verifier when in release + // mode. + pdlPipeline.enableVerifier(false); +#endif + pdlPipeline.addPass(createPDLToPDLInterpPass()); + if (failed(pdlPipeline.run(pdlModule))) + return failure(); + + // Simplify again after running the lowering pipeline. + pdlModule.getBody()->walk(simplifyFn); + return success(); +} + //===----------------------------------------------------------------------===// // FrozenRewritePatternList //===----------------------------------------------------------------------===// FrozenRewritePatternList::FrozenRewritePatternList( OwningRewritePatternList &&patterns) - : patterns(patterns.takePatterns()) {} + : nativePatterns(std::move(patterns.getNativePatterns())) { + PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); + + // Generate the bytecode for the PDL patterns if any were provided. + ModuleOp pdlModule = pdlPatterns.getModule(); + if (!pdlModule) + return; + if (failed(convertPDLToPDLInterp(pdlModule))) + llvm::report_fatal_error( + "failed to lower PDL pattern module to the PDL Interpreter"); + + // Generate the pdl bytecode. + pdlByteCode = std::make_unique<detail::PDLByteCode>( + pdlModule, pdlPatterns.takeConstraintFunctions(), + pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); +} + +FrozenRewritePatternList::FrozenRewritePatternList( + FrozenRewritePatternList &&patterns) + : nativePatterns(std::move(patterns.nativePatterns)), + pdlByteCode(std::move(patterns.pdlByteCode)) {} + +FrozenRewritePatternList::~FrozenRewritePatternList() {} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index 5d6ae51e8eeb..6f5e1f299f26 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -12,17 +12,36 @@ //===----------------------------------------------------------------------===// #include "mlir/Rewrite/PatternApplicator.h" +#include "ByteCode.h" #include "llvm/Support/Debug.h" using namespace mlir; +using namespace mlir::detail; + +PatternApplicator::PatternApplicator( + const FrozenRewritePatternList &frozenPatternList) + : frozenPatternList(frozenPatternList) { + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); + bytecode->initializeMutableState(*mutableByteCodeState); + } +} +PatternApplicator::~PatternApplicator() {} #define DEBUG_TYPE "pattern-match" void PatternApplicator::applyCostModel(CostModel model) { + // Apply the cost model to the bytecode patterns first, and then the native + // patterns. + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + for (auto it : llvm::enumerate(bytecode->getPatterns())) + mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); + } + // Separate patterns by root kind to simplify lookup later on. patterns.clear(); anyOpPatterns.clear(); - for (const auto &pat : frozenPatternList.getPatterns()) { + for (const auto &pat : frozenPatternList.getNativePatterns()) { // If the pattern is always impossible to match, just ignore it. if (pat.getBenefit().isImpossibleToMatch()) { LLVM_DEBUG({ @@ -81,8 +100,12 @@ void PatternApplicator::applyCostModel(CostModel model) { void PatternApplicator::walkAllPatterns( function_ref<void(const Pattern &)> walk) { - for (auto &it : frozenPatternList.getPatterns()) + for (const Pattern &it : frozenPatternList.getNativePatterns()) walk(it); + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + for (const Pattern &it : bytecode->getPatterns()) + walk(it); + } } LogicalResult PatternApplicator::matchAndRewrite( @@ -90,6 +113,14 @@ LogicalResult PatternApplicator::matchAndRewrite( function_ref<bool(const Pattern &)> canApply, function_ref<void(const Pattern &)> onFailure, function_ref<LogicalResult(const Pattern &)> onSuccess) { + // Before checking native patterns, first match against the bytecode. This + // won't automatically perform any rewrites so there is no need to worry about + // conflicts. + SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; + const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); + if (bytecode) + bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); + // Check to see if there are patterns matching this specific operation type. MutableArrayRef<const RewritePattern *> opPatterns; auto patternIt = patterns.find(op->getName()); @@ -98,51 +129,50 @@ LogicalResult PatternApplicator::matchAndRewrite( // Process the patterns for that match the specific operation type, and any // operation type in an interleaved fashion. - // FIXME: It'd be nice to just write an llvm::make_merge_range utility - // and pass in a comparison function. That would make this code trivial. auto opIt = opPatterns.begin(), opE = opPatterns.end(); auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); - while (opIt != opE && anyIt != anyE) { - // Try to match the pattern providing the most benefit. - const RewritePattern *pattern; - if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) - pattern = *(opIt++); - else - pattern = *(anyIt++); + auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end(); + while (true) { + // Find the next pattern with the highest benefit. + const Pattern *bestPattern = nullptr; + const PDLByteCode::MatchResult *pdlMatch = nullptr; + /// Operation specific patterns. + if (opIt != opE) + bestPattern = *(opIt++); + /// Operation agnostic patterns. + if (anyIt != anyE && + (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit())) + bestPattern = *(anyIt++); + /// PDL patterns. + if (pdlIt != pdlE && + (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) { + pdlMatch = pdlIt; + bestPattern = (pdlIt++)->pattern; + } + if (!bestPattern) + break; - // Otherwise, try to match the generic pattern. - if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) - return success(); - } - // If we break from the loop, then only one of the ranges can still have - // elements. Loop over both without checking given that we don't need to - // interleave anymore. - for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>( - llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { - if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) + // Check that the pattern can be applied. + if (canApply && !canApply(*bestPattern)) + continue; + + // Try to match and rewrite this pattern. The patterns are sorted by + // benefit, so if we match we can immediately rewrite. For PDL patterns, the + // match has already been performed, we just need to rewrite. + rewriter.setInsertionPoint(op); + LogicalResult result = success(); + if (pdlMatch) { + bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); + } else { + result = static_cast<const RewritePattern *>(bestPattern) + ->matchAndRewrite(op, rewriter); + } + if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern)))) return success(); - } - return failure(); -} -LogicalResult PatternApplicator::matchAndRewrite( - Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, - function_ref<bool(const Pattern &)> canApply, - function_ref<void(const Pattern &)> onFailure, - function_ref<LogicalResult(const Pattern &)> onSuccess) { - // Check that the pattern can be applied. - if (canApply && !canApply(pattern)) - return failure(); - - // Try to match and rewrite this pattern. The patterns are sorted by - // benefit, so if we match we can immediately rewrite. - rewriter.setInsertionPoint(op); - if (succeeded(pattern.matchAndRewrite(op, rewriter))) - return success(!onSuccess || succeeded(onSuccess(pattern))); - - if (onFailure) - onFailure(pattern); + // Perform any necessary cleanups. + if (onFailure) + onFailure(*bestPattern); + } return failure(); } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir new file mode 100644 index 000000000000..b2a22d0a8749 --- /dev/null +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -0,0 +1,785 @@ +// RUN: mlir-opt %s -test-pdl-bytecode-pass -split-input-file | FileCheck %s + +// Note: Tests here are written using the PDL Interpreter dialect to avoid +// unnecessarily testing unnecessary aspects of the pattern compilation +// pipeline. These tests are written such that we can focus solely on the +// lowering/execution of the bytecode itself. + +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.apply_constraint "multi_entity_constraint"(%root, %root : !pdl.operation, !pdl.operation) -> ^pat, ^end + + ^pat: + pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.replaced_by_pattern"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_1 +// CHECK: "test.replaced_by_pattern" +module @ir attributes { test.apply_constraint_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyRewriteOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %operand = pdl_interp.get_operand 0 of %root + pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_rewrite_1 +// CHECK: %[[INPUT:.*]] = "test.op_input" +// CHECK-NOT: "test.op" +// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]} +module @ir attributes { test.apply_rewrite_1 } { + %input = "test.op_input"() : () -> i32 + "test.op"(%input) : (i32) -> () +} +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::AreEqualOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %test_attr = pdl_interp.create_attribute unit + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.are_equal_1 +// CHECK: "test.success" +module @ir attributes { test.are_equal_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::BranchOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end + + ^pat1: + pdl_interp.branch ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.branch_1 +// CHECK: "test.success" +module @ir attributes { test.branch_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckAttributeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.check_attribute %attr is unit -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_attribute_1 +// CHECK: "test.success" +module @ir attributes { test.check_attribute_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperandCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_operand_count_1 +// CHECK: "test.op"() : () -> i32 +// CHECK: "test.success" +module @ir attributes { test.check_operand_count_1 } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperationNameOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_operation_name_1 +// CHECK: "test.success" +module @ir attributes { test.check_operation_name_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckResultCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is 1 -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_result_count_1 +// CHECK: "test.success"() : () -> () +module @ir attributes { test.check_result_count_1 } { + "test.op"() : () -> i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.check_type %type is i32 -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_type_1 +// CHECK: "test.success" +module @ir attributes { test.check_type_1 } { + "test.op"() { test_attr = 10 : i32 } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateAttributeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateNativeOp +//===----------------------------------------------------------------------===// + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_native_1 +// CHECK: "test.success" +module @ir attributes { test.create_native_1 } { + "test.op"() : () -> () +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateOperationOp +//===----------------------------------------------------------------------===// + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + %test_type = pdl_interp.create_type i32 + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.are_equal %type, %test_type : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_type_1 +// CHECK: "test.success" +module @ir attributes { test.create_type_1 } { + "test.op"() { test_attr = 0 : i32 } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::EraseOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::FinalizeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetDefiningOpOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operand_count of %root is 5 -> ^pat1, ^end + + ^pat1: + %operand0 = pdl_interp.get_operand 0 of %root + %operand4 = pdl_interp.get_operand 4 of %root + %defOp0 = pdl_interp.get_defining_op of %operand0 + %defOp4 = pdl_interp.get_defining_op of %operand4 + pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_defining_op_1 +// CHECK: %[[OPERAND0:.*]] = "test.op" +// CHECK: %[[OPERAND1:.*]] = "test.op" +// CHECK: "test.success" +// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]]) +module @ir attributes { test.get_defining_op_1 } { + %operand = "test.op"() : () -> i32 + %other_operand = "test.op"() : () -> i32 + "test.op"(%operand, %operand, %operand, %operand, %operand) : (i32, i32, i32, i32, i32) -> () + "test.op"(%operand, %operand, %operand, %operand, %other_operand) : (i32, i32, i32, i32, i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetOperandOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetResultOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end + + ^pat1: + %result0 = pdl_interp.get_result 0 of %root + %result4 = pdl_interp.get_result 4 of %root + %result0_type = pdl_interp.get_value_type of %result0 + %result4_type = pdl_interp.get_value_type of %result4 + pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_result_1 +// CHECK: "test.success" +// CHECK: "test.op"() : () -> (i32, i32, i32, i32, i64) +module @ir attributes { test.get_result_1 } { + %a:5 = "test.op"() : () -> (i32, i32, i32, i32, i32) + %b:5 = "test.op"() : () -> (i32, i32, i32, i32, i64) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetValueTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::InferredTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::IsNotNullOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::RecordMatchOp +//===----------------------------------------------------------------------===// + +// Check that the highest benefit pattern is selected. +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@failure(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @failure(%root : !pdl.operation) { + pdl_interp.erase %root + pdl_interp.finalize + } + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.record_match_1 +// CHECK: "test.success" +module @ir attributes { test.record_match_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::ReplaceOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %operand = pdl_interp.get_operand 0 of %root + pdl_interp.replace %root with (%operand) + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.replace_op_1 +// CHECK: %[[INPUT:.*]] = "test.op_input" +// CHECK-NOT: "test.op" +// CHECK: "test.op_consumer"(%[[INPUT]]) +module @ir attributes { test.replace_op_1 } { + %input = "test.op_input"() : () -> i32 + %result = "test.op"(%input) : (i32) -> i32 + "test.op_consumer"(%result) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchAttributeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.switch_attribute %attr to [0, unit](^end, ^pat) -> ^end + + ^pat: + %attr_2 = pdl_interp.get_attribute "test_attr_2" of %root + pdl_interp.switch_attribute %attr_2 to [0, unit](^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_attribute_1 +// CHECK: "test.success" +module @ir attributes { test.switch_attribute_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperandCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_operand_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end + + ^pat: + pdl_interp.switch_operand_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_operand_1 +// CHECK: "test.success" +module @ir attributes { test.switch_operand_1 } { + %input = "test.op_input"() : () -> i32 + "test.op"(%input) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperationNameOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_operation_name of %root to ["foo.op", "test.op"](^end, ^pat1) -> ^end + + ^pat1: + pdl_interp.switch_operation_name of %root to ["foo.op", "bar.op"](^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_operation_name_1 +// CHECK: "test.success" +module @ir attributes { test.switch_operation_name_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchResultCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_result_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end + + ^pat: + pdl_interp.switch_result_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_result_1 +// CHECK: "test.success" +module @ir attributes { test.switch_result_1 } { + "test.op"() : () -> i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.switch_type %type to [i32, i64](^pat2, ^end) -> ^end + + ^pat2: + pdl_interp.switch_type %type to [i16, i64](^end, ^end) -> ^pat3 + + ^pat3: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_type_1 +// CHECK: "test.success" +module @ir attributes { test.switch_type_1 } { + "test.op"() { test_attr = 10 : i32 } : () -> () +} diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt index 0df357c8c355..9b156867702c 100644 --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(Reducer) +add_subdirectory(Rewrite) add_subdirectory(Transforms) diff --git a/mlir/test/lib/Rewrite/CMakeLists.txt b/mlir/test/lib/Rewrite/CMakeLists.txt new file mode 100644 index 000000000000..fd5d5d586160 --- /dev/null +++ b/mlir/test/lib/Rewrite/CMakeLists.txt @@ -0,0 +1,16 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRTestRewrite + TestPDLByteCode.cpp + + EXCLUDE_FROM_LIBMLIR + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRSupport + MLIRTransformUtils + ) + diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp new file mode 100644 index 000000000000..3b23cb103675 --- /dev/null +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -0,0 +1,85 @@ +//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +/// Custom constraint invoked from PDL. +static LogicalResult customSingleEntityConstraint(PDLValue value, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + Operation *rootOp = value.cast<Operation *>(); + return success(rootOp->getName().getStringRef() == "test.op"); +} +static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + return customSingleEntityConstraint(values[1], constantParams, rewriter); +} + +// Custom creator invoked from PDL. +static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, + PatternRewriter &rewriter) { + return rewriter.createOperation( + OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")); +} + +/// Custom rewriter invoked from PDL. +static void customRewriter(Operation *root, ArrayRef<PDLValue> args, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + OperationState successOpState(root->getLoc(), "test.success"); + successOpState.addOperands(args[0].cast<Value>()); + successOpState.addAttribute("constantParams", constantParams); + rewriter.createOperation(successOpState); + rewriter.eraseOp(root); +} + +namespace { +struct TestPDLByteCodePass + : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { + void runOnOperation() final { + ModuleOp module = getOperation(); + + // The test cases are encompassed via two modules, one containing the + // patterns and one containing the operations to rewrite. + ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns"); + ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir"); + if (!patternModule || !irModule) + return; + + // Process the pattern module. + patternModule.getOperation()->remove(); + PDLPatternModule pdlPattern(patternModule); + pdlPattern.registerConstraintFunction("multi_entity_constraint", + customMultiEntityConstraint); + pdlPattern.registerConstraintFunction("single_entity_constraint", + customSingleEntityConstraint); + pdlPattern.registerCreateFunction("creator", customCreate); + pdlPattern.registerRewriteFunction("rewriter", customRewriter); + + OwningRewritePatternList patternList(std::move(pdlPattern)); + + // Invoke the pattern driver with the provided patterns. + (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), + std::move(patternList)); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestPDLByteCodePass() { + PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass", + "Test PDL ByteCode functionality"); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 8857bbe09eef..52e96dc44e0b 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -220,18 +220,21 @@ static void fillL1TilingAndMatmulToVectorPatterns( FuncOp funcOp, StringRef startMarker, SmallVectorImpl<OwningRewritePatternList> &patternsVector) { MLIRContext *ctx = funcOp.getContext(); - patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>( + patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>( ctx, LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}), LinalgMarker(Identifier::get(startMarker, ctx), Identifier::get("L1", ctx)))); - patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>( - ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx)))); + patternsVector.emplace_back( + std::make_unique<LinalgPromotionPattern<MatmulOp>>( + ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), + LinalgMarker(Identifier::get("L1", ctx), + Identifier::get("VEC", ctx)))); - patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>( - ctx, LinalgMarker(Identifier::get("VEC", ctx)))); + patternsVector.emplace_back( + std::make_unique<LinalgVectorizationPattern<MatmulOp>>( + ctx, LinalgMarker(Identifier::get("VEC", ctx)))); patternsVector.back() .insert<LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>>(ctx); @@ -437,7 +440,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp, fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { - stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>( + stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>( ctx, LinalgTilingOptions() .setTileSizes({768, 264, 768}) diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index e8b0842a9e33..8bee2f5faa75 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -19,6 +19,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestIR MLIRTestPass MLIRTestReducer + MLIRTestRewrite MLIRTestTransforms ) endif() diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 4095cc21cbaf..67aa855092ef 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -86,6 +86,7 @@ void registerTestMemRefStrideCalculation(); void registerTestNumberOfBlockExecutionsPass(); void registerTestNumberOfOperationExecutionsPass(); void registerTestOpaqueLoc(); +void registerTestPDLByteCodePass(); void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); void registerTestSCFUtilsPass(); @@ -155,6 +156,7 @@ void registerTestPasses() { test::registerTestNumberOfBlockExecutionsPass(); test::registerTestNumberOfOperationExecutionsPass(); test::registerTestOpaqueLoc(); + test::registerTestPDLByteCodePass(); test::registerTestRecursiveTypesPass(); test::registerTestSCFUtilsPass(); test::registerTestSparsification(); |