[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Jan 10 09:10:34 PST 2025


================
@@ -69,3 +69,35 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :
 // CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
 // CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+  %a = arith.constant dense<2> : tensor<2x2xi32>
+  %b = arith.constant 42 : i32
+  %c = tensor.empty() : tensor<2x2xi32>
+  %res = linalg.generic
+    {
+      indexing_maps = [
+        affine_map<(i, j) -> (i, j)>, 
+        affine_map<(i, j) -> ()>,     
+        affine_map<(i, j) -> (i, j)>  
+      ],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%a, %b : tensor<2x2xi32>, i32)
+    outs(%c : tensor<2x2xi32>) {
+  ^bb0(%x: i32, %scalar: i32, %out: i32):
+    %sum = arith.addi %x, %scalar : i32
+    linalg.yield %sum : i32
+  } -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: test_broadcast_scalar_across_single_tensor
+// CHECK-SAME: () -> tensor<2x2xi32> {
+// CHECK:   %[[E0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%cst, %c42_i32 : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {
+// CHECK:   ^bb0(%in: i32, %in_0: i32, %out: i32):
+// CHECK:     %[[E0:.+]] = arith.addi %in, %in_0 : i32
+// CHECK:     linalg.yield %2 : i32
+// CHECK:   } -> tensor<2x2xi32>
----------------
adam-smnk wrote:

I think it might be too much detail in this case.
A simple `// CHECK: linalg.generic` should be sufficient to validate that specialization did nothing.

FYI, in tests where you do care about details like the SSA values populating maps, `ins`, etc., it is best to capture them directly from the graph through other checks. For example:
```
// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i32
// CHECK: linalg.generic{{.*}}ins(%[[C42]],
```
Otherwise, plain printed names like `%c42_i32` might change, which makes the test unreliable.
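
As a rough sketch of how that could look for this test if the detailed checks were kept: every SSA value is captured where it is defined and then reused, so nothing depends on printed names such as `%cst`, `%in_0`, or `%2`. The capture names (`CST`, `C42`, `IN`, `IN0`, `SUM`) are made up for illustration, and the sketch assumes both constants are still printed before the `linalg.generic`, as in the output quoted above.
```
// CHECK-LABEL: func.func @test_broadcast_scalar_across_single_tensor
// Capture the constants in any order instead of relying on printed names.
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i32
// CHECK: linalg.generic{{.*}}ins(%[[CST]], %[[C42]] :
// Capture the block arguments and the result of the add as well.
// CHECK: ^bb0(%[[IN:.+]]: i32, %[[IN0:.+]]: i32, %{{.+}}: i32):
// CHECK: %[[SUM:.+]] = arith.addi %[[IN]], %[[IN0]] : i32
// CHECK: linalg.yield %[[SUM]] : i32
```
If the only goal is to show that specialization left the op alone, the `CHECK-LABEL` line followed by a single `// CHECK: linalg.generic` is enough.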



https://github.com/llvm/llvm-project/pull/122449

