[Mlir-commits] [mlir] 8432f24 - [mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic (#153420)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 13 16:45:09 PDT 2025


Author: Sayan Saha
Date: 2025-08-13T19:45:06-04:00
New Revision: 8432f24831b21c34eb8c03b7cdc75c6f59326395

URL: https://github.com/llvm/llvm-project/commit/8432f24831b21c34eb8c03b7cdc75c6f59326395
DIFF: https://github.com/llvm/llvm-project/commit/8432f24831b21c34eb8c03b7cdc75c6f59326395.diff

LOG: [mlir][tosa] Don't fold mul with zero lhs/rhs if resulting type is dynamic (#153420)

Canonicalizing the following IR:

```
func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
  return %2 : tensor<?x17xf32>
}
```

resulted in a crash

```
#0 0x000056513187e8db backtrace (./build-release/bin/mlir-opt+0x9d698db)                                                                                                                                                                                                                                                                                                                   
 #1 0x0000565131b17737 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:838:8                                                                                                                                                                                                                
 #2 0x0000565131b187f3 PrintStackTraceSignalHandler(void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:918:1                                                                                                                                                                                                                                
 #3 0x0000565131b18c30 llvm::sys::RunSignalHandlers() /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Signals.cpp:105:18                                                                                                                                                                                                                                         
 #4 0x0000565131b18c30 SignalHandler(int, siginfo_t*, void*) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/llvm/lib/Support/Unix/Signals.inc:409:3                                                                                                                                                                                                                              
 #5 0x00007f2e4165b050 (/lib/x86_64-linux-gnu/libc.so.6+0x3c050)                                                                                                                                                                                                                                                                                                                            
 #6 0x00007f2e416a9eec __pthread_kill_implementation ./nptl/pthread_kill.c:44:76                                                                                                                                                                                                                                                                                                            
 #7 0x00007f2e4165afb2 raise ./signal/../sysdeps/posix/raise.c:27:6                                                                                                                                                                                                                                                                                                                         
 #8 0x00007f2e41645472 abort ./stdlib/abort.c:81:7                                                                                                                                                                                                                                                                                                                                          
 #9 0x00007f2e41645395 _nl_load_domain ./intl/loadmsgcat.c:1177:9                                                                                                                                                                                                                                                                                                                           
#10 0x00007f2e41653ec2 (/lib/x86_64-linux-gnu/libc.so.6+0x34ec2)                                                                                                                                                                                                                                                                                                                            
#11 0x00005651443ec4ba mlir::DenseIntOrFPElementsAttr::getRaw(mlir::ShapedType, llvm::ArrayRef<char>) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:1361:3                                                                                                                                                                                    
#12 0x00005651443f1209 mlir::DenseElementsAttr::resizeSplat(mlir::ShapedType) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/IR/BuiltinAttributes.cpp:0:10                                                                                                                                                                                                              
#13 0x000056513f76f2b6 mlir::tosa::MulOp::fold(mlir::tosa::MulOpGenericAdaptor<llvm::ArrayRef<mlir::Attribute>>) /local-ssd/sayans/Softwares/llvm-repo/llvm-project-latest/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp:0:0
```

from the folder for `tosa::mul` since the zero value was being reshaped
to `?x17` size which isn't supported. AFAIK, `tosa.const` requires all
dimensions to be static. So in this case, the fix is to not to fold the
op.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..fce61f27ca3ea 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   }
 
   if (rhsTy == resultTy) {
-    if (isSplatZero(resultETy, lhsAttr))
+    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      // constant values can only be resized if resulting type is static
       return lhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
   if (lhsTy == resultTy) {
-    if (isSplatZero(resultETy, rhsAttr))
+    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
       return rhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee36e9e5e..930bb9fe96811 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
 
 // -----
 
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK:           %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK:           return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK:           return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+  %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+  return %2 : tensor<?x17xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>


        


More information about the Mlir-commits mailing list