[Mlir-commits] [mlir] ca02f36 - [mlir][vector] Teach `TransferOptimization` to forward masked stores (#87794)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 16 01:52:26 PDT 2024


Author: Benjamin Maxwell
Date: 2024-05-16T09:52:21+01:00
New Revision: ca02f36bacaec58071a26ff65329fbab5526d1d7

URL: https://github.com/llvm/llvm-project/commit/ca02f36bacaec58071a26ff65329fbab5526d1d7
DIFF: https://github.com/llvm/llvm-project/commit/ca02f36bacaec58071a26ff65329fbab5526d1d7.diff

LOG: [mlir][vector] Teach `TransferOptimization` to forward masked stores (#87794)

This only handles one case (which is fairly common in practice*): storing
a masked constant splat, then reloading it with the same mask and a
padding value that matches the splat.

* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..58951641d33ce 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,43 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+/// Returns true if `write` stores a constant splat and the masked `read` is
+/// padded with the same splat value, so its masked-off lanes also yield the
+/// splat. Index, type and permutation-map checks are left to the caller.
+static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
+                                                 vector::TransferReadOp read) {
+  auto readMask = read.getMask();
+  auto writeMask = write.getMask();
+  // Check that the masks are consistent: the read must be masked (masked-off
+  // lanes take the padding value, checked below) and the write must be
+  // unmasked or use the identical mask SSA value. A masked write followed by
+  // an unmasked read is rejected, since the read may then cover lanes the
+  // write never stored, which could hold a different value.
+  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
+  if (!couldBeSameSplat)
+    return false;
+  // The vector being written must be a constant whose elements all match.
+  DenseElementsAttr splatAttr;
+  if (!matchPattern(write.getVector(),
+                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
+      !splatAttr.isSplat()) {
+    return false;
+  }
+  // The read's padding must be a constant equal to the splat element value.
+  Attribute padAttr;
+  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+    return false;
+  return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() &&
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         ((!defWrite.getMask() && !read.getMask()) ||
+          isSplatWriteConsistentWithMaskedRead(defWrite, read));
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,

diff  --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..0719c0dd17427 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,128 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//       CHECK:   %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[SPLAT]])
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Here the read is eliminated but the unmasked write is not (masks differ).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+//       CHECK:   %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+//       CHECK:   vector.transfer_write %[[SPLAT]]
+//       CHECK:   scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[SPLAT]])
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the read's padding (1.0) differs from the splat value (0.0),
+// so the store can't be forwarded.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the read and write use different masks (%mask_b vs %mask_a),
+// so the store can't be forwarded.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the write is masked but the read is unmasked. We can't
+// forward the store (the write may cover fewer elements than the read).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}


        


More information about the Mlir-commits mailing list