[Mlir-commits] [mlir] [mlir][TosaToLinalg] `RescaleConverter` only support integer type (PR #114239)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 30 07:32:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR fixes a bug in `RescaleConverter`: the pattern accepted inputs whose element type is not an integer type, which led to a crash. It now reports a match failure for such inputs instead.
Fixes #<!-- -->61383.

---
Full diff: https://github.com/llvm/llvm-project/pull/114239.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 251c48859d2de0..5291f95d371442 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1153,6 +1153,9 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
+    if (!isa<IntegerType>(inputTy.getElementType()))
+      return rewriter.notifyMatchFailure(op, "only support integer type");
+
     SmallVector<Value> dynDims;
     for (int i = 0; i < outputTy.getRank(); i++) {
       if (outputTy.isDynamicDim(i)) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index b78577275a52a6..ea1b79cbd9507f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,3 +36,12 @@ func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x
   %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
   return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @rescale_unsupported_type
+func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
+  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/114239


More information about the Mlir-commits mailing list