[Mlir-commits] [mlir] 7d650bf - [mlir][tosa] Fix several bugs in `DepthwiseConv2DIsMul` (#129210)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 3 13:45:04 PST 2025
Author: Longsheng Mou
Date: 2025-03-03T13:45:00-08:00
New Revision: 7d650bf3318b51cee7f89c4792e0f9b36bcdcc46
URL: https://github.com/llvm/llvm-project/commit/7d650bf3318b51cee7f89c4792e0f9b36bcdcc46
DIFF: https://github.com/llvm/llvm-project/commit/7d650bf3318b51cee7f89c4792e0f9b36bcdcc46.diff
LOG: [mlir][tosa] Fix several bugs in `DepthwiseConv2DIsMul` (#129210)
This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- The pattern should only match a DepthwiseConv2DOp whose input and weight
element types are integer or float, to prevent a crash on other element
types (e.g. quantized types).
- `notifyMatchFailure` should be called before creating any new
operations, so the zero-point checks now run ahead of the reshape and cast
creation.
Added:
Modified:
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 25a159bbc9644..14ee422a31541 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -48,6 +48,25 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return failure();
}
+ Type inputETy = inputType.getElementType();
+ Type weightETy = weightType.getElementType();
+ if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "unsupported type");
+
+ // Get and verify zero points.
+ int64_t iZp;
+ int64_t wZp;
+
+ if (op.getInputZeroPoint(iZp).failed() ||
+ op.getWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "bail out if zero points cannot statically be determined");
+
+ if (op.verifyInputZeroPoint(iZp).failed() ||
+ op.verifyWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "zero point must be zero for non-int8 integer types");
+
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,8 +81,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
revisedInputShapeValue)
.getResult();
- Type inputETy = inputType.getElementType();
- Type weightETy = weightType.getElementType();
Type resultETy = resultType.getElementType();
if (inputETy != resultETy) {
@@ -76,20 +93,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
}
- // Get and verify zero points.
- int64_t iZp;
- int64_t wZp;
-
- if (op.getInputZeroPoint(iZp).failed() ||
- op.getWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "bail out if zero points cannot statically be determined");
-
- if (op.verifyInputZeroPoint(iZp).failed() ||
- op.verifyWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "zero point must be zero for non-int8 integer types");
-
if (iZp != 0 || wZp != 0) {
auto applyZp = [&](Value val, int64_t zp) -> Value {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 7cd44ba475dbb..3eb72c6bcccf6 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -76,3 +76,25 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
}
+
+// -----
+
+// Decompose only support integer or float types.
+
+// CHECK-LABEL: @depthwise_conv2d_quant_type
+func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
+ %0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: tosa.depthwise_conv2d
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+ return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_no_const_zero_point
+func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<4x10x10x6xi32> {
+ // CHECK: tosa.depthwise_conv2d
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
+ return %0 : tensor<4x10x10x6xi32>
+}
More information about the Mlir-commits
mailing list