summaryrefslogtreecommitdiff
path: root/mlir/include
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2023-02-21 21:04:00 +0000
committerAlex Zinenko <zinenko@google.com>2023-05-16 08:16:56 +0000
commit2fe4d90cac54bf35948bea2ca6d5d8c510f6a1b4 (patch)
tree2e22324deee5b3e04182bbe0db07edecc5f8d118 /mlir/include
parentaf0121fb8f793e5142d445cc2192e5c4a33bb21f (diff)
downloadllvm-2fe4d90cac54bf35948bea2ca6d5d8c510f6a1b4.tar.gz
[mlir] make structured transform ops use types
Types have been introduced a while ago and provide for better readability and transform-time verification. Use them in the ops from the structured transform dialect extension. In most cases, the types are appended as trailing functional types or a derived format of the functional type that allows for an empty right hand size without the annoying `-> ()` syntax (similarly to `func.func` declaration that may omit the arrow). When handles are used inside mixed static/dynamic lists, such as tile sizes, types of those handles follow them immediately as in `sizes [%0 : !transform.any_value, 42]`. This allows for better readability than matching the trailing type. Update code to remove hardcoded PDL dependencies and expunge PDL from structured transform op code. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D144515
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td187
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h47
-rw-r--r--mlir/include/mlir/Dialect/Transform/Utils/Utils.h18
-rw-r--r--mlir/include/mlir/Interfaces/ViewLikeInterface.h40
4 files changed, 192 insertions, 100 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c7bc3767b27c..4f78b7d6c80d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -14,7 +14,6 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -96,15 +95,16 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
#### Return modes
This operation ignores non-Linalg ops and drops them in the return.
- If all the operations referred to by the `target` PDLOperation decompose
+ If all the operations referred to by the `target` handle decompose
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
computational operations, which can be empty.
}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -127,11 +127,11 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
- let results = (outs PDL_Operation:$transformed,
- Variadic<PDL_Operation>:$loops);
+ let results = (outs TransformHandleTypeInterface:$transformed,
+ Variadic<TransformHandleTypeInterface>:$loops);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
@@ -181,10 +181,11 @@ def FuseIntoContainingOp :
This operation only reads the containing op handle.
}];
- let arguments = (ins PDL_Operation:$producer_op,
- PDL_Operation:$containing_op);
- let results = (outs PDL_Operation:$fused_op);
- let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+ let arguments = (ins TransformHandleTypeInterface:$producer_op,
+ TransformHandleTypeInterface:$containing_op);
+ let results = (outs TransformHandleTypeInterface:$fused_op);
+ let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+ " `:` functional-type(operands, results)";
let builders = [
OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
@@ -205,16 +206,18 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
#### Return modes
This operation ignores non-Linalg ops and drops them in the return.
- If all the operations referred to by the `target` PDLOperation generalize
+ If all the operations referred to by the `target` handle generalize
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
equivalent generic operations, which can be empty or contain the original
ops if they were already in generic form.
}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+ let assemblyFormat =
+ "$target attr-dict `:` "
+ "custom<SemiFunctionType>(type($target), type($transformed))";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -239,7 +242,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
This operation ignores non-linalg::Generic ops and drops them in the return.
This operation fails if the interchange attribute is invalid.
- If all the operations referred to by the `target` PDLOperation interchange
+ If all the operations referred to by the `target` handle interchange
properly, the transform succeeds.
If any interchange fails, the transform definitely fails.
The return handle points to only the subset of successfully produced
@@ -247,14 +250,15 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins TransformHandleTypeInterface:$target,
ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">,
[DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat = [{
$target
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
+ `:` custom<SemiFunctionType>(type($target), type($transformed))
}];
let hasVerifier = 1;
@@ -552,13 +556,14 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<PDL_Operation>:$packed_sizes,
+ Variadic<TransformHandleTypeInterface>:$packed_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_packed_sizes);
let results = (outs TransformHandleTypeInterface:$packed_op);
let assemblyFormat = [{
$target
`packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
- $static_packed_sizes)
+ $static_packed_sizes,
+ type($packed_sizes))
attr-dict
`:` functional-type($target, results)
}];
@@ -637,7 +642,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
// TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<PDL_Operation>:$matmul_packed_sizes,
+ Variadic<TransformHandleTypeInterface>:$matmul_packed_sizes,
ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
[DenseArrayCount<3>]>:$static_matmul_packed_sizes,
ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
@@ -662,7 +667,8 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
$target
oilist(
`matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
- $static_matmul_packed_sizes)
+ $static_matmul_packed_sizes,
+ type($matmul_packed_sizes))
(`matmul_padded_sizes_next_multiple_of` `=`
$matmul_padded_sizes_next_multiple_of^)?
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
@@ -758,23 +764,25 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
This operation ignores non-Linalg ops and drops them in the return.
This operation may produce a definiteFailure if the padding fails for any
reason.
- If all the operations referred to by the `target` PDLOperation pad
+ If all the operations referred to by the `target` handle pad
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
padded operations, which can be empty.
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
DefaultValuedAttr<
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
"{}">:$transpose_paddings);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:` "
+ "custom<SemiFunctionType>(type($target), type($transformed))";
let hasVerifier = 1;
let extraClassDeclaration = [{
@@ -898,23 +906,25 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
This operation applies to a single Linalg op that satisfies the
`promoteSubviewsPrecondition`, otherwise it fails.
- If the operations referred to by the `target` PDLOperation promote
+ If the operations referred to by the `target` handle promote
properly, the transform succeeds.
When successful, the return handle points to the $target operation that
was modified inplace.
}];
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$operands_to_promote,
DefaultValuedAttr<BoolArrayAttr, "{}">:$use_full_tile_buffers,
UnitAttr:$use_full_tiles_by_default,
UnitAttr:$use_alloca,
OptionalAttr<DeviceMappingArrayAttr>:$mapping,
OptionalAttr<I64Attr>:$alignment);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:`"
+ "custom<SemiFunctionType>(type($target), type($transformed))";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -943,10 +953,12 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
This operation consumes the `target` handle.
}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$replacement);
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$replacement);
let regions = (region SizedRegion<1>:$bodyRegion);
- let assemblyFormat = "$target attr-dict-with-keyword regions";
+ let assemblyFormat =
+ "$target attr-dict-with-keyword regions `:` "
+ "custom<SemiFunctionType>(type($target), type($replacement))";
let hasVerifier = 1;
}
@@ -966,7 +978,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
This operation ignores non-Linalg ops and drops them in the return.
This operation produces `definiteFailure` if the scalarization fails for any
reason.
- If all the operations referred to by the `target` PDLOperation scalarize
+ If all the operations referred to by the `target` handle scalarize
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
@@ -980,10 +992,12 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
needed.
}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$result);
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$result);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:`"
+ "custom<SemiFunctionType>(type($target), type($result))";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1016,13 +1030,13 @@ def RewriteInDestinationPassingStyleOp : Op<
#### Return modes
This operation ignores non-unsupported ops and drops them from the return.
- If all the operations referred to by the `target` PDLOperation generalize
+ If all the operations referred to by the `target` handle generalize
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to a subset of successfully produced operations:
- - tensor.pad case, the returned handle points to the tensor.insert_slice.
- - tensor.generate case, the returned handle points to the linalg.generic.
- - tensor.from_elements case, the returned handle points to the last
- tensor.insert.
+ - `tensor.pad` case, the returned handle points to the tensor.insert_slice.
+ - `tensor.generate` case, the returned handle points to the linalg.generic.
+ - `tensor.from_elements` case, the returned handle points to the last
+ `tensor.insert`.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
@@ -1110,7 +1124,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
This operation produces `definiteFailure` if the splitting fails for any
reason.
- If all the operations referred to by the `target` PDLOperation split
+ If all the operations referred to by the `target` handle split
properly, the transform succeeds. Otherwise the transform silently fails.
The 4 returned handles points to only the subset of successfully produced
computational operations, which can all be empty.
@@ -1219,18 +1233,20 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
```
}];
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64Attr, "{}">:$split_factor,
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
UnitAttr:$inner_parallel,
UnitAttr:$use_scaling_algorithm,
UnitAttr:$use_alloc);
- let results = (outs PDL_Operation:$init_or_alloc_op,
- PDL_Operation:$fill_op,
- PDL_Operation:$split_linalg_op,
- PDL_Operation:$combining_linalg_op);
+ let results = (outs TransformHandleTypeInterface:$init_or_alloc_op,
+ TransformHandleTypeInterface:$fill_op,
+ TransformHandleTypeInterface:$split_linalg_op,
+ TransformHandleTypeInterface:$combining_linalg_op);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:`"
+ "functional-type(operands, results)";
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1326,12 +1342,12 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
}];
// TODO: support mixed static-dynamic (see TileToForallOp).
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
- let results = (outs PDL_Operation:$for_op,
- PDL_Operation:$fill_op,
- PDL_Operation:$split_linalg_op,
- PDL_Operation:$combining_linalg_op);
+ let results = (outs TransformHandleTypeInterface:$for_op,
+ TransformHandleTypeInterface:$fill_op,
+ TransformHandleTypeInterface:$split_linalg_op,
+ TransformHandleTypeInterface:$combining_linalg_op);
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1342,6 +1358,7 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
$target
`by` `tile_sizes` `=` $tile_sizes
attr-dict
+ `:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
@@ -1427,14 +1444,14 @@ def TileReductionUsingForallOp :
}];
// TODO: support mixed static-dynamic (see TileToForallOp).
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
- let results = (outs PDL_Operation:$forall_op,
- PDL_Operation:$fill_op,
- PDL_Operation:$split_linalg_op,
- PDL_Operation:$combining_linalg_op);
+ let results = (outs TransformHandleTypeInterface:$forall_op,
+ TransformHandleTypeInterface:$fill_op,
+ TransformHandleTypeInterface:$split_linalg_op,
+ TransformHandleTypeInterface:$combining_linalg_op);
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1450,6 +1467,7 @@ def TileReductionUsingForallOp :
(`,` `tile_sizes` `=` $tile_sizes^)?
(`,` `mapping` `=` $mapping^)?
attr-dict
+ `:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
@@ -1577,7 +1595,7 @@ def TileToForallOp :
This operation ignores ops that do not implement the TilingInterface and
drops them in the return.
- If all the operations referred to by the `target` PDLOperation tile
+ If all the operations referred to by the `target` handle tile
successfully, the transform succeeds.
Otherwise the transform silently fails.
@@ -1604,16 +1622,16 @@ def TileToForallOp :
```
}];
- let arguments = (ins PDL_Operation:$target,
- Variadic<PDL_Operation>:$num_threads,
- Variadic<PDL_Operation>:$tile_sizes,
- Optional<PDL_Operation>:$packed_num_threads,
- Optional<PDL_Operation>:$packed_tile_sizes,
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<TransformHandleTypeInterface>:$num_threads,
+ Variadic<TransformHandleTypeInterface>:$tile_sizes,
+ Optional<TransformHandleTypeInterface>:$packed_num_threads,
+ Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
- let results = (outs PDL_Operation:$forall_op,
- PDL_Operation:$tiled_op);
+ let results = (outs TransformHandleTypeInterface:$forall_op,
+ TransformHandleTypeInterface:$tiled_op);
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1641,12 +1659,17 @@ def TileToForallOp :
let assemblyFormat = [{
$target oilist(
`num_threads` custom<PackedOrDynamicIndexList>($packed_num_threads,
+ type($packed_num_threads),
$num_threads,
+ type($num_threads),
$static_num_threads) |
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
+ type($packed_tile_sizes),
$tile_sizes,
+ type($tile_sizes),
$static_tile_sizes))
(`(` `mapping` `=` $mapping^ `)`)? attr-dict
+ `:` functional-type($target, results)
}];
let hasVerifier = 1;
@@ -1705,12 +1728,12 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
produces a definite failure.
}];
- let arguments = (ins PDL_Operation:$target,
- Variadic<PDL_Operation>:$dynamic_sizes,
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<TransformHandleTypeInterface>:$dynamic_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
- let results = (outs PDL_Operation:$tiled_linalg_op,
- Variadic<PDL_Operation>:$loops);
+ let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
+ Variadic<TransformHandleTypeInterface>:$loops);
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1760,7 +1783,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
- `disable_transfer_permutation_map_lowering_patterns`: a UnitAttr to
deactivate the rewrite of `vector.transfer` with permutation maps into
explicit `vector.transpose` operations. This is intended to be used in
- tests only but may be promotoed to a first class attribute in the future.
+ tests only but may be promoted to a first class attribute in the future.
#### Return modes:
@@ -1770,14 +1793,16 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
to be isolated from above.
}];
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$disable_multi_reduction_to_contract_patterns,
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:`"
+ "functional-type(operands, results)";
let builders = [
OpBuilder<(ins "Value":$target,
@@ -1812,13 +1837,13 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
#### Return modes:
This operation produces a definite failure if the dynamic vector sizes (SSA
- values) do not satify the constraints mentioned above. It produces a
+ values) do not satisfy the constraints mentioned above. It produces a
silenceable failure if at least one target op is not a Linalg op or fails to
vectorize.
}];
- let arguments = (ins PDL_Operation:$target,
- Variadic<PDL_Operation>:$vector_sizes,
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<TransformHandleTypeInterface>:$vector_sizes,
UnitAttr:$vectorize_nd_extract,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes);
@@ -1826,8 +1851,10 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
let assemblyFormat = [{
$target
`vector_sizes` custom<DynamicIndexList>($vector_sizes,
- $static_vector_sizes)
+ $static_vector_sizes,
+ type($vector_sizes))
attr-dict
+ `:` type($target)
}];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
new file mode 100644
index 000000000000..13b0dc0e0b95
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
@@ -0,0 +1,47 @@
+//===- Syntax.h - Custom syntax for Linalg transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ParseResult;
+class OpAsmParser;
+class OpAsmPrinter;
+class Type;
+class TypeRange;
+class Operation;
+
+/// Parses a single non-function type or a function type with at least one
+/// argument. This allows for the following syntax:
+///
+/// - type: just the argument type;
+/// - `(` type `)` `->` type: one argument and one result type;
+/// - `(` type `)` `->` `(` comma-separated-type-list `)`: one argument and
+/// multiple result types.
+///
+/// Unlike FunctionType, this allows and requires one to omit the parens around
+/// the argument type in absence of result types, and does not accept the
+/// trailing `-> ()` construct, which makes the syntax nicer for operations.
+ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
+ Type &resultType);
+ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
+ SmallVectorImpl<Type> &resultTypes);
+
+/// Prints argument and result types in a syntax similar to that of FunctionType
+/// but allowing and requiring one to omit the parens around the argument type
+/// in absence of result types, and without the trailing `-> ()`.
+void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+ Type argumentType, TypeRange resultType);
+void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+ Type argumentType, Type resultType);
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
diff --git a/mlir/include/mlir/Dialect/Transform/Utils/Utils.h b/mlir/include/mlir/Dialect/Transform/Utils/Utils.h
index 04a0b090e6b9..97b193bc723d 100644
--- a/mlir/include/mlir/Dialect/Transform/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Transform/Utils/Utils.h
@@ -22,7 +22,8 @@ class TransformState;
/// Printer hook for custom directive in assemblyFormat.
///
-/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+/// custom<PackedOrDynamicIndexList>($packed, type($packed), $values,
+/// type($values), $integers)
///
/// where `values` are variadic Index values, `integers` is an `I64ArrayAttr`
/// and `packed` is a single transform dialect handle who's mapped payload ops
@@ -30,20 +31,23 @@ class TransformState;
/// or the other two parameters may be specified.
///
/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list or with a single handle. E.g., `[%arg0, 7, 42, %arg42]` or just `%h`.
+/// list or with a single handle. E.g., `[%arg0 : !transform.any_op, 7, 42,
+/// %arg42 : !transform.param<i64>]` or just `%h : !transform.any_op`.
void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
- Value packed, OperandRange values,
+ Value packed, Type packedType,
+ OperandRange values, TypeRange valueTypes,
ArrayRef<int64_t> integers);
-/// Pasrer hook for custom directive in assemblyFormat.
+/// Parser hook for custom directive in assemblyFormat.
///
-/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+/// custom<PackedOrDynamicIndexList>($packed, type($packed), $values,
+/// type($values), $integers)
///
/// See `printPackedOrDynamicIndexList` for details.
ParseResult parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers);
+ Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers);
} // namespace transform
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 5843ecd061df..87113197524f 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -42,35 +42,49 @@ namespace mlir {
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<DynamicIndexList>($values, $integers)
+/// custom<DynamicIndexList>($values, $integers, type($values))
///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
-/// in `integers` is `dynVal` or (2) the next value otherwise. This allows
-/// idiomatic printing of mixed value and integer attributes in a list. E.g.
-/// `[%arg0, 7, 42, %arg42]`.
+/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
+/// is non-empty, it is expected to contain as many elements as `values`
+/// indicating their types. This allows idiomatic printing of mixed value and
+/// integer attributes in a list. E.g.
+/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
- ArrayRef<int64_t> integers,
+ ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
-/// Pasrer hook for custom directive in assemblyFormat.
+/// Parser hook for custom directive in assemblyFormat.
///
/// custom<DynamicIndexList>($values, $integers)
+/// custom<DynamicIndexList>($values, $integers, type($values))
///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
-/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
-/// to `values` in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+/// `kDynamic` encodes the position of SSA values. Add the parsed SSA values
+/// to `values` in-order. If `valueTypes` is non-null, fill it with types
+/// corresponding to values; otherwise the caller must handle the types.
+///
+/// E.g. after parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
+/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
+/// `kDynamic`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult parseDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ return parseDynamicIndexList(parser, values, integers, &valueTypes,
+ delimiter);
+}
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.