[Mlir-commits] [mlir] [mlir][scf] Fix `FoldTensorCastOfOutputIntoForallOp` for multi-result scf.forall (PR #173271)
Mehdi Amini
llvmlistbot at llvm.org
Sun Dec 28 07:54:45 PST 2025
================
@@ -2028,6 +2028,43 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// -----
+// CHECK-LABEL: func.func @fold_tensor_cast_into_forall_with_multiple_result(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 16 : index
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<32xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FORALL_0:.*]]:2 = scf.forall (%[[VAL_0:.*]]) in (4) shared_outs(%[[VAL_1:.*]] = %[[EMPTY_0]], %[[VAL_2:.*]] = %[[EMPTY_1]]) -> (tensor<32xf32>, tensor<64xf32>) {
+// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_1]] : index
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ARG0]] into %[[VAL_2]]{{\[}}%[[MULI_1]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
+// CHECK: tensor.parallel_insert_slice %[[ARG1]] into %[[VAL_1]]{{\[}}%[[MULI_0]]] [8] [1] : tensor<8xf32> into tensor<32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[CAST_0:.*]] = tensor.cast %[[FORALL_0]]#0 : tensor<32xf32> to tensor<?xf32>
+// CHECK: return %[[CAST_0]], %[[FORALL_0]]#1 : tensor<?xf32>, tensor<64xf32>
+// CHECK: }
----------------
joker-eph wrote:
Please restrict the checks to the minimum needed to match (that should be a handful of checks).
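For illustration, a reduced set of checks along the lines the review asks for might keep only the lines this test exists to verify. Those are the two-result `scf.forall` with its static result types, the single `tensor.cast` applied to result `#0`, and the `return` that pairs the cast value with the untouched result `#1`. This is a hypothetical sketch, not the revision that was committed in the PR, and the capture names `FORALL` and `CAST` are illustrative.

```mlir
// Hypothetical minimal checks (a sketch, not the committed test).
// CHECK-LABEL: func.func @fold_tensor_cast_into_forall_with_multiple_result(
//       CHECK:   %[[FORALL:.*]]:2 = scf.forall {{.*}} -> (tensor<32xf32>, tensor<64xf32>)
//       CHECK:   %[[CAST:.*]] = tensor.cast %[[FORALL]]#0 : tensor<32xf32> to tensor<?xf32>
//       CHECK:   return %[[CAST]], %[[FORALL]]#1 : tensor<?xf32>, tensor<64xf32>
```

The omitted lines (the `arith.constant` and `arith.muli` index computations and the `tensor.empty` ops) appear to be incidental to the fold, so leaving them out should keep the test from breaking on unrelated changes. If the test also needs to pin which shared output each `tensor.parallel_insert_slice` writes into, one check per insert could be added back.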
https://github.com/llvm/llvm-project/pull/173271