[Mlir-commits] [mlir] 18f8928 - [mlir][tosa] Fix mul folder conformance to the spec (#137601)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 8 02:09:50 PDT 2025


Author: Luke Hutton
Date: 2025-05-08T10:09:46+01:00
New Revision: 18f89283ebac87a153708b8fe00056f96b83022a

URL: https://github.com/llvm/llvm-project/commit/18f89283ebac87a153708b8fe00056f96b83022a
DIFF: https://github.com/llvm/llvm-project/commit/18f89283ebac87a153708b8fe00056f96b83022a.diff

LOG: [mlir][tosa] Fix mul folder conformance to the spec (#137601)

Change the folder for mul with a shift such that the rounding happens
correctly according to the spec
pseudo-code.

Fixes:
https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from:
https://github.com/llvm/llvm-project/pull/128059

Co-authored-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..e73e2c4e33522 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 }
 
 namespace {
+// Calculate (lhs * rhs) >> shift per the TOSA spec (rounds when shift > 0).
+// Returns nullopt if shift > 0 and the result is outside the int32_t range.
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+                            unsigned bitwidth) {
+  APInt result = lhs.sext(64) * rhs.sext(64);
+
+  if (shift > 0) {
+    auto round = APInt(64, 1) << (shift - 1);
+    result += round;
+    result.ashrInPlace(shift);
+    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+    if (!(result.getSExtValue() >= INT32_MIN &&
+          result.getSExtValue() <= INT32_MAX)) {
+      // REQUIRE failed: the shifted product does not fit in int32_t
+      return std::nullopt;
+    }
+  }
+
+  return result.trunc(bitwidth);
+}
+
 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                   RankedTensorType ty, int32_t shift) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
       }
 
       auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
-      l = l.sext(bitwidth * 2);
-      r = r.sext(bitwidth * 2);
-      auto result = l * r;
-      result.lshrInPlace(shift);
-      result = result.trunc(bitwidth);
-      return DenseElementsAttr::get(ty, result);
+      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
+      if (!result)
+        return {};
+      return DenseElementsAttr::get(ty, result.value());
     }
 
     if (llvm::isa<FloatType>(ty.getElementType())) {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c98335cdafe65 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
     %1 = tosa.const_shape  {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
     %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
     return %2 : tensor<2x60x58x?xf32>
-  }
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_shift() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_no_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_no_shift() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
+// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}


        


More information about the Mlir-commits mailing list