[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Apr 15 06:35:01 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/87794

>From d0b36e6a4fdb286aab1090ef412b1fdbad84ea3a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Apr 2024 10:44:33 +0000
Subject: [PATCH 1/2] [mlir][vector] Teach `TransferOptimization` to forward
 masked stores

This only handles one case (that's fairly common in practice*): storing
a constant splat (unmasked, or masked with the same mask as the read),
then reloading it with a masked read whose padding value matches the
splat.

* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 31 +++++++++--
 .../Dialect/Vector/vector-transferop-opt.mlir | 52 ++++++++++++++++++-
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..1dacafe3d7fabc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+                                        vector::TransferReadOp read) {
+  auto writeMask = defWrite.getMask();
+  auto readMask = read.getMask();
+  // Unmasked write then unmasked read: every lane is stored, then reloaded.
+  if (!writeMask && !readMask)
+    return true;
+  // Otherwise only a constant splat can be forwarded: the read must be masked
+  // (padding fills its disabled lanes) and the write unmasked or same-masked.
+  if (!readMask || (writeMask && writeMask != readMask))
+    return false;
+  DenseElementsAttr storedSplat;
+  bool storesConstantSplat =
+      matchPattern(defWrite.getVector(),
+                   m_Constant<DenseElementsAttr>(&storedSplat)) &&
+      storedSplat.isSplat();
+  Attribute paddingAttr;
+  // Every lane the read returns is either the splat or the padding, so the
+  // two constants must be identical.
+  return storesConstantSplat &&
+         matchPattern(read.getPadding(), m_Constant(&paddingAttr)) &&
+         paddingAttr == storedSplat.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() && // Write must be fully in-bounds.
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         couldBeSameValueWithMasking(defWrite, read); // Mask/padding check.
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..2c8f105cd5c14b 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32> // Splat stored by the masked write.
+  %cst_f32 = arith.constant 0.0 : f32 // Read padding; equals the splat, so the store is forwarded.
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the read padding (1.0) does not match the stored splat (0.0),
+// so the store cannot be forwarded.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32 // Padding differs from the 0.0 splat.
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}

>From e4435c359a917e91cc6d2f29d6d9eec5c9a83c7e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Apr 2024 13:32:48 +0000
Subject: [PATCH 2/2] Fixup - remove unused `vector.vscale` ops from tests

---
 mlir/test/Dialect/Vector/vector-transferop-opt.mlir | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 2c8f105cd5c14b..b2fa5c68c17a31 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -374,7 +374,6 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  %vscale = vector.vscale
   vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
@@ -400,7 +399,6 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c512 = arith.constant 512 : index
-  %vscale = vector.vscale
   vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {



More information about the Mlir-commits mailing list