[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)

Diego Caballero llvmlistbot at llvm.org
Wed Apr 17 00:30:25 PDT 2024

@@ -360,3 +360,51 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>

dcaballe wrote:

Could we also add a test where the `vector.transfer_write` and the `vector.transfer_read` use different masks?
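
Something along these lines, perhaps. This is only a sketch: the function name and the `%mask0`/`%mask1` arguments are made up for illustration, and the CHECK lines assume the pattern declines to forward when the masks differ. That seems like the only safe outcome, since lanes enabled in the read mask but disabled in the write mask would observe the old buffer contents rather than the splat.

```mlir
// Hypothetical negative test: the store and the load use different masks,
// so the stored splat cannot be forwarded to the read.
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_different_masks
//       CHECK:   vector.transfer_write
//       CHECK:   vector.transfer_read
//       CHECK:   scf.for
//       CHECK:   }
//       CHECK:   vector.transfer_write
//       CHECK:   return
func.func @forward_dead_constant_splat_store_with_different_masks(%buffer : memref<?x?xf32>, %mask0: vector<[8]x[8]xi1>, %mask1: vector<[8]x[8]xi1>) {
  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  %cst_f32 = arith.constant 0.0 : f32
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c512 = arith.constant 512 : index
  // The write is guarded by %mask0 ...
  vector.transfer_write %cst, %buffer[%c0, %c0], %mask0 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  // ... while the read is guarded by %mask1.
  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask1 {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %1 : vector<[8]x[8]xf32>
  }
  vector.transfer_write %x, %buffer[%c0, %c0], %mask1 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
```

The padding here matches the splat (both 0.0), so the mask mismatch is the only thing that should block the forwarding.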