[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)
llvmlistbot at llvm.org
Tue Jan 30 05:54:31 PST 2024
================
@@ -310,49 +310,48 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
// -----
-func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
- %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %C_elem : f32
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @nested_fuse
-// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
-// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C1FP:%.*]] = arith.constant 1.
----------------
fabrizio-indirli wrote:
**✓** Done
https://github.com/llvm/llvm-project/pull/79187