[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)

Benjamin Maxwell llvmlistbot at llvm.org
Mon May 13 06:19:09 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/87794

>From ed9183fcc039035c182a3c2771cb209cfd16de5a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Apr 2024 10:44:33 +0000
Subject: [PATCH 1/3] [mlir][vector] Teach `TransferOptimization` to forward
 masked stores

This only handles one case (which is fairly common in practice*): storing
a masked constant splat, then reloading it with the same mask and a
padding value that matches the splat.

* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 31 +++++++++--
 .../Dialect/Vector/vector-transferop-opt.mlir | 52 ++++++++++++++++++-
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..8d4eac1324d40 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+                                        vector::TransferReadOp read) {
+  if (!defWrite.getMask() && !read.getMask())
+    return true; // Success: No masks (values will be the same).
+  // Check for constant splats. These will be the same value if the read is
+  // masked (and padded with the splat value), and the write is unmasked or has
+  // the same mask.
+  bool couldBeSameSplatValue =
+      read.getMask() &&
+      (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+  if (!couldBeSameSplatValue)
+    return false;
+  DenseElementsAttr splatAttr;
+  if (!matchPattern(defWrite.getVector(),
+                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
+      !splatAttr.isSplat()) {
+    return false;
+  }
+  Attribute padAttr;
+  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+    return false;
+  return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() &&
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         couldBeSameValueWithMasking(defWrite, read);
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..2c8f105cd5c14 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}

>From 89931de90bafd501f8b194facaba88d7e83919e1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Apr 2024 13:32:48 +0000
Subject: [PATCH 2/3] Fixup - remove unused ops

---
 mlir/test/Dialect/Vector/vector-transferop-opt.mlir | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 2c8f105cd5c14..b2fa5c68c17a3 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -374,7 +374,6 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  %vscale = vector.vscale
   vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
@@ -400,7 +399,6 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  %vscale = vector.vscale
   vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {

>From 884e977b51f8bdb05860ed5095f2c593573ab53d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Apr 2024 13:51:14 +0000
Subject: [PATCH 3/3] More docs and tests

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 10 ++-
 .../Dialect/Vector/vector-transferop-opt.mlir | 72 ++++++++++++++++---
 2 files changed, 70 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d4eac1324d40..df9c2c2e5dcf5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,10 +170,16 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+/// Returns true if the value written by `defWrite` could be the same as the
+/// value read by `read`. Note: True is 'could be' not 'definitely' (as this
+/// simply looks at the masks and the value written). For a definite answer use
+/// `checkSameValueRAW()` -- which calls this function.
 static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
                                         vector::TransferReadOp read) {
-  if (!defWrite.getMask() && !read.getMask())
-    return true; // Success: No masks (values will be the same).
+  if (!defWrite.getMask() && !read.getMask()) {
+    // Success: No masks (values could be the same).
+    return true;
+  }
   // Check for constant splats. These will be the same value if the read is
   // masked (and padded with the splat value), and the write is unmasked or has
   // the same mask.
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index b2fa5c68c17a3..74fca321cd442 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -362,20 +362,47 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
 }
 
 // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//       CHECK:   %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
 //   CHECK-NOT:   vector.transfer_write
 //   CHECK-NOT:   vector.transfer_read
 //       CHECK:   scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[SPLAT]])
 //       CHECK:   }
 //       CHECK:   vector.transfer_write
 //       CHECK:   return
 func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
-  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
-  %cst_f32 = arith.constant 0.0 : f32
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
-  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Here the read can be eliminated but not the write (due to mismatched masks).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+//       CHECK:   %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+//       CHECK:   vector.transfer_write %[[SPLAT]]
+//       CHECK:   scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[SPLAT]])
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
     scf.yield %1 : vector<[8]x[8]xf32>
@@ -386,21 +413,21 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
 
 // Negative test, the padding does not match the constant splat, so we can't
 // forward the store.
-// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
 //       CHECK:   vector.transfer_write
 //       CHECK:   vector.transfer_read
 //       CHECK:   scf.for
 //       CHECK:   }
 //       CHECK:   vector.transfer_write
 //       CHECK:   return
-func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
-  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
-  %cst_f32 = arith.constant 1.0 : f32
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 1.0 : f32
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
-  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
     scf.yield %1 : vector<[8]x[8]xf32>
@@ -408,3 +435,28 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   return
 }
+
+// Negative test, the masks don't match between the read and write, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}



More information about the Mlir-commits mailing list