summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp')
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7b4733864972..44f64f76e9b0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -34,9 +34,9 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value weight = op.getWeight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ ShapedType weightType = cast<ShapedType>(weight.getType());
+ ShapedType resultType = cast<ShapedType>(op.getType());
auto numDynamic =
llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
@@ -66,7 +66,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto quantizationInfo = op.getQuantizationInfo();
int64_t iZp = quantizationInfo->getInputZp();
- if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+ if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");
@@ -116,7 +116,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(weight.getType()).getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,