[Mlir-commits] [mlir] 949148c - [mlir][tosa] Fix argmax folder when output type is i64 (#163583)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 22 09:35:26 PDT 2025


Author: Luke Hutton
Date: 2025-10-22T17:35:22+01:00
New Revision: 949148c1f33b96cf8893741e16129684e595b0ce

URL: https://github.com/llvm/llvm-project/commit/949148c1f33b96cf8893741e16129684e595b0ce
DIFF: https://github.com/llvm/llvm-project/commit/949148c1f33b96cf8893741e16129684e595b0ce.diff

LOG: [mlir][tosa] Fix argmax folder when output type is i64 (#163583)

Previously, the following IR:
```
tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
```
would result in a crash with the assertion:
```
expected dense element bit width 64 to match data size 32 for type i64
```

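This commit ensures that the zero value is constructed with the bitwidth of
the output element type while folding, which fixes the crash. The fold is
now only applied when the output element type is an integer type.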

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf80165fc640..99b7cda49094e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
       !outputTy.hasStaticShape())
     return {};
 
-  if (inputTy.getDimSize(getAxis()) == 1)
-    return DenseElementsAttr::get(outputTy, 0);
+  const Type outputElementTy = getElementTypeOrSelf(outputTy);
+  if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+    const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+    const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+    return DenseElementsAttr::get(outputTy, zero);
+  }
 
   return {};
 }

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e8525a5d2ed62..7574afa215e78 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
 
 // -----
 
+// CHECK-LABEL: @test_argmax_fold_i64_index
+func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
+  // CHECK: "tosa.const"() <{values = dense<0> : tensor<i64>}> : () -> tensor<i64>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
+  return %0 : tensor<i64>
+}
+
+// -----
+
 // CHECK-LABEL: @pad_wh_avg_pool2d_fold
 func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
   // CHECK-NOT: tosa.pad


        


More information about the Mlir-commits mailing list