[Mlir-commits] [mlir] 18f8928 - [mlir][tosa] Fix mul folder conformance to the spec (#137601)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 8 02:09:50 PDT 2025
Author: Luke Hutton
Date: 2025-05-08T10:09:46+01:00
New Revision: 18f89283ebac87a153708b8fe00056f96b83022a
URL: https://github.com/llvm/llvm-project/commit/18f89283ebac87a153708b8fe00056f96b83022a
DIFF: https://github.com/llvm/llvm-project/commit/18f89283ebac87a153708b8fe00056f96b83022a.diff
LOG: [mlir][tosa] Fix mul folder conformance to the spec (#137601)
Change the folder for mul with a shift so that the rounding happens
correctly according to the spec pseudo-code.
Fixes:
https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from:
https://github.com/llvm/llvm-project/pull/128059
Co-authored-by: Tai Ly <tai.ly at arm.com>
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..e73e2c4e33522 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+// Calculate (lhs * rhs) >> shift in 64 bits with rounding, per the TOSA spec.
+// Returns std::nullopt if shift > 0 and the result is outside int32_t range.
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64); // sign-extend both operands to 64 bits first
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1); // rounding term: 1 << (shift - 1)
+ result += round;
+ result.ashrInPlace(shift); // arithmetic shift right, applied after adding the rounding term
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed: report that there is no valid result
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth); // narrow the 64-bit value to the requested bit width
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
- auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
- return DenseElementsAttr::get(ty, result);
+ const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
+ if (!result)
+ return {};
+ return DenseElementsAttr::get(ty, result.value());
}
if (llvm::isa<FloatType>(ty.getElementType())) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c98335cdafe65 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
%1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
return %2 : tensor<2x60x58x?xf32>
- }
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_no_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_no_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
+// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
More information about the Mlir-commits
mailing list