[Mlir-commits] [mlir] [mlir][scf]Fix scf.forall inlining: add shared outputs (PR #132197)

Matthias Springer llvmlistbot at llvm.org
Fri Mar 21 01:00:38 PDT 2025


================
@@ -55,3 +55,26 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
   }
   return
 }
+
+// -----
+
+  func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
+    %c100 = arith.constant 100 : index
+    %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
+      %t = "test.foo"() : () -> tensor<100xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
+      }
+    }
+    return %res : tensor<100xf32>
+  }
+// CHECK-LABEL:   func.func @parallel_insert_slice(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<100xf32>) -> tensor<100xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
+// CHECK:           }
+// CHECK:           return %[[VAL_0]] : tensor<100xf32>
----------------
matthias-springer wrote:

Basically, instead of dropping the terminator of the `scf.forall` loop, you have to replace each `tensor.parallel_insert_slice` inside it with a `tensor.insert_slice` and yield the result. Also, the loop nest that this pass generates must have one iter_arg (and one result) per shared_out.
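
For illustration, here is a rough sketch of the IR I would expect for the test case above. It is hand-written, not produced by the pass; the value names `%c0`, `%c1` and `%updated` are made up, and the offset/size/stride operands are copied verbatim from the test input.

```mlir
func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
  %c100 = arith.constant 100 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // The shared_out %s becomes an iter_arg, initialized with %arg0.
  %res = scf.for %i = %c0 to %c100 step %c1
      iter_args(%s = %arg0) -> (tensor<100xf32>) {
    %t = "test.foo"() : () -> tensor<100xf32>
    // tensor.parallel_insert_slice from the terminator turns into a
    // sequential tensor.insert_slice on the current iter_arg value.
    %updated = tensor.insert_slice %t into %s[%i] [100] [1]
        : tensor<100xf32> into tensor<100xf32>
    // Yield the updated tensor so the next iteration sees it.
    scf.yield %updated : tensor<100xf32>
  }
  // Return the loop result, not the original %arg0.
  return %res : tensor<100xf32>
}
```

With something along these lines, the CHECK lines would need to match the `iter_args`, the `tensor.insert_slice`, the `scf.yield`, and a `return` of the loop result instead of `%[[VAL_0]]`.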

https://github.com/llvm/llvm-project/pull/132197

