[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 5 08:14:47 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This only handles one case (which is fairly common in practice*): storing a masked constant splat, then reading it back with the same mask and a padding value that matches the splat.

* For SVE/SME (without peeling), this occurs when a `linalg.fill` precedes a `linalg.matmul`.

---
Full diff: https://github.com/llvm/llvm-project/pull/87794.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+28-3) 
- (modified) mlir/test/Dialect/Vector/vector-transferop-opt.mlir (+51-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..1dacafe3d7fabc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+                                        vector::TransferReadOp read) {
+  if (!defWrite.getMask() && !read.getMask())
+    return true; // Success: No masks (values will be the same).
+  // Check for constant splats. These will be the same value if the read is
+  // masked (and padded with the splat value), and the write is unmasked or has
+  // the same mask.
+  bool couldBeSameSplatValue =
+      read.getMask() &&
+      (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+  if (!couldBeSameSplatValue)
+    return false;
+  DenseElementsAttr splatAttr;
+  if (!matchPattern(defWrite.getVector(),
+                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
+      !splatAttr.isSplat()) {
+    return false;
+  }
+  Attribute padAttr;
+  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+    return false;
+  return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() &&
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         couldBeSameValueWithMasking(defWrite, read);
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..2c8f105cd5c14b 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87794


More information about the Mlir-commits mailing list