[Mlir-commits] [mlir] [MLIR][Affine] Fix affine loop fusion with vector ops #115849, #120227 (PR #122799)

Uday Bondhugula llvmlistbot at llvm.org
Mon Feb 3 00:50:20 PST 2025


================
@@ -285,3 +286,147 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }
 }
+
+// -----
+
+// Basic test for not fusing loops where a vector load depends on the
+// entire result of a previous loop: the store shape is smaller than the
+// load shape, so every producer iteration must finish before the load.
+
+// CHECK-LABEL: func @should_not_fuse_across_memref_store_load_bounds
+func.func @should_not_fuse_across_memref_store_load_bounds() {
+  %a = memref.alloc() : memref<64x512xf32>
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32>
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+      %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+      %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+      affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+
+  return
+}
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8
+// CHECK: %[[lhs:.*]] = affine.vector_load %[[a]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:.*]] = affine.vector_load %[[b]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], %[[c]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: affine.for %[[j_2:.*]] = 0 to 8
+// CHECK: %[[lhs_2:.*]] = affine.vector_load %[[c]][0, 0] : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: %[[rhs_2:.*]] = affine.vector_load %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x512xf32>
+// CHECK: affine.vector_store %[[res_2]], %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: return
+
+// -----
+
+// Basic test for not fusing loops where the dependences involve an
+// affine vector store and scalar affine loads.
+
+// CHECK-LABEL: func @should_not_fuse_vector_store_non_vector_load
+func.func @should_not_fuse_vector_store_non_vector_load() -> memref<64x4096xf32> {
+  %c0 = arith.constant 0 : index
+  %a = memref.alloc() : memref<64x512xf32> 
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32> 
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+    affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %k = 0 to 64 {
+    affine.for %m = 0 to 4096 {
+      affine.for %l = 0 to 512 {
+        %lhs = affine.load %c[%k, %l] : memref<64x512xf32>
+        %rhs = affine.load %d[%k, %m] : memref<64x4096xf32>
+        %res = arith.subf %lhs, %rhs : f32
+        affine.store %res, %d[%k, %m] : memref<64x4096xf32>
+      }
+    }
+  }
+
+  return %d : memref<64x4096xf32>
+}
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8 {
+// CHECK:   %[[lhs:.*]] = affine.vector_load %[[a]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK:   %[[rhs:.*]] = affine.vector_load %[[b]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK:   %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK:   affine.vector_store %[[res]], %[[c]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %[[k:.*]] = 0 to 64 {
+// CHECK:   affine.for %[[m:.*]] = 0 to 4096 {
+// CHECK:     affine.for %[[l:.*]] = 0 to 512 {
+// CHECK:       %[[lhs_2:.*]] = affine.load %[[c]][%[[k]], %[[l]]] : memref<64x512xf32>
+// CHECK:       %[[rhs_2:.*]] = affine.load %[[d]][%[[k]], %[[m]]] : memref<64x4096xf32>
+// CHECK:       %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : f32
+// CHECK:       affine.store %[[res_2]], %[[d]][%[[k]], %[[m]]] : memref<64x4096xf32>
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+// CHECK: return %[[d]] : memref<64x4096xf32>
+
+// -----
+
+// Basic test for fusing loops where a vector load depends only on the
+// partial result of a previous loop: each store covers the region the
+// matching load reads, so per-iteration fusion is legal.
+
+// CHECK-LABEL: func @should_fuse_across_memref_store_load_bounds
+func.func @should_fuse_across_memref_store_load_bounds() -> memref<64x4096xf32> {
+  %c0 = arith.constant 0 : index
+  %a = memref.alloc() : memref<64x512xf32> 
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32> 
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+      %res = arith.subf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+  }
+  return %d : memref<64x4096xf32>
+}
----------------
bondhugula wrote:

All the test cases assume single-element private memrefs, so the vector shape ends up being the private memref shape. That isn't true in the general case: you'll need to multiply the vector shape by the computed dimension sizes of the private memref region. The test cases will have to be augmented accordingly; see another test case where the private memref isn't unit-sized.
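For illustration, a minimal sketch of the shape adjustment being described, with hypothetical names (`scaleByVectorShape`, `regionShape`); this is not the upstream `createPrivateMemRef` code:

```cpp
// Hypothetical sketch, not the actual createPrivateMemRef implementation:
// scale each private-memref region dimension by the stored vector's shape,
// so a non-unit region still reserves room for one full vector per point.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

static llvm::SmallVector<int64_t>
scaleByVectorShape(llvm::ArrayRef<int64_t> regionShape,
                   mlir::Operation *srcStoreOp) {
  llvm::SmallVector<int64_t> shape(regionShape.begin(), regionShape.end());
  if (auto vecStore =
          llvm::dyn_cast<mlir::affine::AffineVectorStoreOp>(srcStoreOp)) {
    llvm::ArrayRef<int64_t> vecShape = vecStore.getVectorType().getShape();
    // Assumes the vector rank matches the memref rank, as in the tests above.
    for (unsigned d = 0, e = shape.size(); d < e; ++d)
      shape[d] *= vecShape[d];
  }
  return shape;
}
```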

https://github.com/llvm/llvm-project/pull/122799

