[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Apr 5 08:14:13 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/87794
This only handles one case (which is fairly common in practice*): storing a masked constant splat, then reloading it with the same mask and a padding value that matches the splat.
* For SVE/SME (without peeling), this occurs when a `linalg.fill` precedes a `linalg.matmul`.
>From 5790171e1d8e3ff4f72e197b9e4e08d249c69104 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Apr 2024 10:44:33 +0000
Subject: [PATCH] [mlir][vector] Teach `TransferOptimization` to forward masked
stores
This only handles one case (that's fairly common in practice*), storing
a masked constant splat, then reloading again with the same mask and a
padding value that matches the splat.
* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 31 +++++++++--
.../Dialect/Vector/vector-transferop-opt.mlir | 52 ++++++++++++++++++-
2 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..1dacafe3d7fabc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+ vector::TransferReadOp read) { // Mask-related part of the RAW check: inspects only the two masks, the written vector and the read padding.
+ if (!defWrite.getMask() && !read.getMask())
+ return true; // Success: No masks (values will be the same).
+ // Check for constant splats. These will be the same value if the read is
+ // masked (and padded with the splat value), and the write is unmasked or has
+ // the same mask.
+ bool couldBeSameSplatValue =
+ read.getMask() &&
+ (!defWrite.getMask() || defWrite.getMask() == read.getMask()); // NOTE(review): presumably compares the mask values by identity, not by contents -- confirm.
+ if (!couldBeSameSplatValue)
+ return false; // Either the write is masked but the read is not, or the two masks differ.
+ DenseElementsAttr splatAttr;
+ if (!matchPattern(defWrite.getVector(),
+ m_Constant<DenseElementsAttr>(&splatAttr)) ||
+ !splatAttr.isSplat()) {
+ return false; // The written vector is not a constant DenseElementsAttr splat.
+ }
+ Attribute padAttr;
+ if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+ return false; // The padding is not a constant, so it cannot be compared with the splat.
+ return padAttr == splatAttr.getSplatValue<Attribute>(); // The read padding must equal the splat element.
+}
+
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
- !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+ return !defWrite.hasOutOfBoundsDim() &&
+ defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
- defWrite.getPermutationMap() == read.getPermutationMap();
+ defWrite.getPermutationMap() == read.getPermutationMap() &&
+ couldBeSameValueWithMasking(defWrite, read); // Replaces the previous blanket "neither op has a mask" requirement.
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..2c8f105cd5c14b 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32> // Constant splat stored by the first masked write.
+ %cst_f32 = arith.constant 0.0 : f32 // Read padding; equals the splat element (0.0).
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale // NOTE(review): %vscale is not used anywhere in this function.
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32> // Masked store of the splat.
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32> // Same indices and same %mask as the store above.
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32> // The only write the CHECK lines expect to remain.
+ return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32> // Constant splat (0.0) stored by the first masked write.
+ %cst_f32 = arith.constant 1.0 : f32 // Read padding (1.0) differs from the splat element (0.0).
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale // NOTE(review): %vscale is not used anywhere in this function.
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32> // Expected to be kept (see CHECK lines).
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32> // Expected to be kept (see CHECK lines).
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
More information about the Mlir-commits
mailing list