[Mlir-commits] [mlir] [mlir][tosa] Fix mul folder conformance to the spec (PR #137601)
Luke Hutton
llvmlistbot at llvm.org
Mon Apr 28 02:09:18 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/137601
Change the folder for mul with a shift such that the rounding happens correctly according to the spec
pseudo-code.
Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: https://github.com/llvm/llvm-project/pull/128059
Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
>From 95d0be212167c79c001740b12ae3e5dee8b6af2d Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 25 Apr 2025 15:58:33 +0000
Subject: [PATCH] [mlir][tosa] Fix mul folder conformance to the spec
Change the folder for mul with a shift such that the
rounding happens correctly according to the spec
pseudo-code.
Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: https://github.com/llvm/llvm-project/pull/128059
Co-authored-by: Tai Ly <tai.ly at arm.com>
Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 31 +++++++++++---
mlir/test/Dialect/Tosa/canonicalize.mlir | 41 ++++++++++++++++++-
2 files changed, 65 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..e73e2c4e33522 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+// Calculate lhs * rhs >> shift per the TOSA spec (64-bit, rounded half up).
+// Return nullopt if shift > 0 and the shifted result does not fit in int32_t.
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64);
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1);
+ result += round;
+ result.ashrInPlace(shift);
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed: report "no result" so the caller can bail out.
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth);
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
- auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
- return DenseElementsAttr::get(ty, result);
+ const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
+ if (!result)
+ return {};
+ return DenseElementsAttr::get(ty, result.value());
}
if (llvm::isa<FloatType>(ty.getElementType())) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c98335cdafe65 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
%1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
return %2 : tensor<2x60x58x?xf32>
- }
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_no_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_no_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
+// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
More information about the Mlir-commits
mailing list