summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederik Gossen <frgossen@google.com>2020-09-09 07:44:38 +0000
committerFrederik Gossen <frgossen@google.com>2020-09-09 07:45:46 +0000
commit133322d2e30877d5039643ab5c2ed02f75c29466 (patch)
tree25fa9150498d31cf69afa9f62de12baab90ec23f
parentfdc8a1aac293084ffb2d7f04b1225c8e2fb3b164 (diff)
downloadllvm-133322d2e30877d5039643ab5c2ed02f75c29466.tar.gz
[MLIR][Standard] Update `tensor_from_elements` assembly format
Remove the redundant parenthesis that are used for none of the other operation formats. Differential Revision: https://reviews.llvm.org/D86287
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/IR/Ops.td11
-rw-r--r--mlir/lib/Dialect/StandardOps/IR/Ops.cpp18
-rw-r--r--mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir6
-rw-r--r--mlir/test/IR/core-ops.mlir12
-rw-r--r--mlir/test/IR/invalid-ops.mlir4
-rw-r--r--mlir/test/Transforms/canonicalize.mlir2
6 files changed, 28 insertions, 25 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index f326ae557865..c276818589af 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1621,14 +1621,9 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements",
let results = (outs AnyTensor:$result);
let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "OpBuilder &builder, OperationState &result, ValueRange elements", [{
- assert(!elements.empty() && "expected at least one element");
- result.addOperands(elements);
- result.addTypes(
- RankedTensorType::get({static_cast<int64_t>(elements.size())},
- *elements.getTypes().begin()));
- }]>];
+ let builders = [
+ OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
+ ];
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 65f8b83d9a71..1c6901987019 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1744,9 +1744,9 @@ static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
Type resultType;
- if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
- parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColon() || parser.parseType(resultType))
+ if (parser.parseOperandList(elementsOperands) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(resultType))
return failure();
if (parser.resolveOperands(elementsOperands,
@@ -1759,9 +1759,9 @@ static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
- p << "tensor_from_elements(" << op.elements() << ')';
+ p << "tensor_from_elements " << op.elements();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.result().getType();
+ p << " : " << op.getType();
}
static LogicalResult verify(TensorFromElementsOp op) {
@@ -1778,6 +1778,14 @@ static LogicalResult verify(TensorFromElementsOp op) {
return success();
}
+void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
+ ValueRange elements) {
+ assert(!elements.empty() && "expected at least one element");
+ result.addOperands(elements);
+ result.addTypes(RankedTensorType::get({static_cast<int64_t>(elements.size())},
+ *elements.getTypes().begin()));
+}
+
namespace {
// Canonicalizes the pattern of the form
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index bf8e74e5143e..4d2437a4877b 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -94,7 +94,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
- // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
+ // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
@@ -223,7 +223,7 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
%shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
return
}
@@ -238,7 +238,7 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// CHECK-DAG: %[[C5:.*]] = constant 5 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
%shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
return
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 69e974bc4173..e4472b444f03 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -661,17 +661,17 @@ func @extract_element(%arg0: tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
// CHECK-LABEL: func @tensor_from_elements() {
func @tensor_from_elements() {
%c0 = "std.constant"() {value = 0: index} : () -> index
- // CHECK: %0 = tensor_from_elements(%c0) : tensor<1xindex>
- %0 = tensor_from_elements(%c0) : tensor<1xindex>
+ // CHECK: %0 = tensor_from_elements %c0 : tensor<1xindex>
+ %0 = tensor_from_elements %c0 : tensor<1xindex>
%c1 = "std.constant"() {value = 1: index} : () -> index
- // CHECK: %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
- %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+ // CHECK: %1 = tensor_from_elements %c0, %c1 : tensor<2xindex>
+ %1 = tensor_from_elements %c0, %c1 : tensor<2xindex>
%c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32
// CHECK: [[C0_F32:%.*]] = constant
- // CHECK: %2 = tensor_from_elements([[C0_F32]]) : tensor<1xf32>
- %2 = tensor_from_elements(%c0_f32) : tensor<1xf32>
+ // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
+ %2 = tensor_from_elements %c0_f32 : tensor<1xf32>
return
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 55739119aa26..71b007ef6e39 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -597,7 +597,7 @@ func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
func @tensor_from_elements_wrong_result_type() {
// expected-error@+2 {{expected result type to be a ranked tensor}}
%c0 = constant 0 : i32
- %0 = tensor_from_elements(%c0) : tensor<*xi32>
+ %0 = tensor_from_elements %c0 : tensor<*xi32>
return
}
@@ -606,7 +606,7 @@ func @tensor_from_elements_wrong_result_type() {
func @tensor_from_elements_wrong_elements_count() {
// expected-error@+2 {{expected result type to be a 1D tensor with 1 element}}
%c0 = constant 0 : index
- %0 = tensor_from_elements(%c0) : tensor<2xindex>
+ %0 = tensor_from_elements %c0 : tensor<2xindex>
return
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 7333446c6e5d..76fe82588be3 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -981,7 +981,7 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
func @extract_element_from_tensor_from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
%c0 = constant 0 : index
- %tensor = tensor_from_elements(%element) : tensor<1xindex>
+ %tensor = tensor_from_elements %element : tensor<1xindex>
%extracted_element = extract_element %tensor[%c0] : tensor<1xindex>
// CHECK: [[ARG]] : index
return %extracted_element : index