[Mlir-commits] [mlir] [mlir][tosa] Add pass to assign static input shape to TOSA functions (PR #171156)

Sayan Saha llvmlistbot at llvm.org
Wed Dec 10 05:08:35 PST 2025


================
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-experimental-input-shape="args=arg0:2x16,arg2:64x9" %s | FileCheck %s
+
+// CHECK-LABEL: test_empty_func
+func.func @test_empty_func(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %arg0, %arg1, %arg2 : tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>
+    return %arg0, %arg1, %arg2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_func_with_ops
+func.func @test_func_with_ops(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %[[ADD:.*]] = tosa.add %arg0, %arg0 : (tensor<2x16xi32>, tensor<2x16xi32>)
+    %0 = tosa.add %arg0, %arg0 : (tensor<2x?xi32>, tensor<2x?xi32>) -> tensor<2x?xi32>
+    // CHECK: %[[RECIP:.*]] =  tosa.reciprocal %arg1 : (tensor<?x256xf32>)
+    %1 = tosa.reciprocal %arg1 : (tensor<?x256xf32>) -> tensor<?x256xf32>
+    // CHECK: %[[SUB:.*]] = tosa.sub %arg2, %arg2 : (tensor<64x9xf32>, tensor<64x9xf32>)
+    %2 = tosa.sub %arg2, %arg2 : (tensor<?x9xf32>, tensor<?x9xf32>) -> tensor<?x9xf32>
+    return %0, %1, %2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_controlflow
+func.func @test_controlflow(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>,
+        // CHECK: %arg3: tensor<i1>
+        %arg3: tensor<i1>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %[[IF:.*]]:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
----------------
sahas3 wrote:

Oh I see, thanks for the clarification. I think that without running `tosa-infer-shapes`, this IR cannot be lowered out of TOSA.

The following works:
```
$> mlir-opt --pass-pipeline='builtin.module(func.func(tosa-infer-shapes,  tosa-to-scf))' 

module {
  func.func @test_controlflow(%arg0: tensor<2x16xi32>, %arg1: tensor<?x256xf32>, %arg2: tensor<64x9xf32>, %arg3: tensor<i1>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
    %extracted = tensor.extract %arg3[] : tensor<i1>
    %0:3 = scf.if %extracted -> (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) {
      scf.yield %arg0, %arg1, %arg2 : tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>
    } else {
      scf.yield %arg0, %arg1, %arg2 : tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>
    }
    %cast = tensor.cast %0#2 : tensor<64x9xf32> to tensor<?x9xf32>
    %cast_0 = tensor.cast %0#0 : tensor<2x16xi32> to tensor<2x?xi32>
    return %cast_0, %0#1, %cast : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
  }
}
```

But without inferring shapes, it fails:
```
$> mlir-opt --pass-pipeline='builtin.module(func.func(tosa-to-scf))' 

error: 'scf.if' op along control flow edge from Operation scf.yield to parent results: source type #0 'tensor<2x16xi32>' should match input type #0 'tensor<2x?xi32>'
    %0:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
           ^
/tmp/tosa.mlir:9:12: note: see current operation: 
%1:3 = "scf.if"(%0) ({
  "scf.yield"(%arg0, %arg1, %arg2) : (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> ()
}, {
  "scf.yield"(%arg0, %arg1, %arg2) : (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> ()
}) : (i1) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>)
```

Is there any requirement that a pass must not produce illegal IR? If there is no such requirement, this is fine.

https://github.com/llvm/llvm-project/pull/171156


More information about the Mlir-commits mailing list