[Mlir-commits] [mlir] ca02f36 - [mlir][vector] Teach `TransferOptimization` to forward masked stores (#87794)
llvmlistbot at llvm.org
Thu May 16 01:52:26 PDT 2024
Author: Benjamin Maxwell
Date: 2024-05-16T09:52:21+01:00
New Revision: ca02f36bacaec58071a26ff65329fbab5526d1d7
URL: https://github.com/llvm/llvm-project/commit/ca02f36bacaec58071a26ff65329fbab5526d1d7
DIFF: https://github.com/llvm/llvm-project/commit/ca02f36bacaec58071a26ff65329fbab5526d1d7.diff
LOG: [mlir][vector] Teach `TransferOptimization` to forward masked stores (#87794)
This only handles one case (which is fairly common in practice*): storing
a masked constant splat, then reloading with the same mask and a padding
value that matches the splat.
* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
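
For illustration, the forwarded pattern looks roughly like the sketch
below (hypothetical shapes and value names, not taken from the patch):

  // Before: masked store of a zero splat, then a reload with the same
  // mask and a matching padding value (0.0).
  %splat = arith.constant dense<0.0> : vector<[4]xf32>
  %pad = arith.constant 0.0 : f32
  vector.transfer_write %splat, %mem[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<?xf32>
  %v = vector.transfer_read %mem[%c0], %pad, %mask {in_bounds = [true]} : memref<?xf32>, vector<[4]xf32>

  // After: %v can be replaced with %splat, since active lanes read back
  // the stored splat value and inactive lanes are padded with that same
  // value. The write may then also become dead and be erased.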
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/vector-transferop-opt.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..58951641d33ce 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,43 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+/// Check if `write` is of a constant splat and the masked `read` is padded with
+/// the same splat value -- meaning it could be the same value as the initial
+/// constant splat.
+static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
+ vector::TransferReadOp read) {
+ auto readMask = read.getMask();
+ auto writeMask = write.getMask();
+ // Check if the masks are consistent. The splat value could be the same if the
+ // read is masked (and padded with the splat value), and the write is unmasked
+ // or has the same mask. Note this does not allow the case where the write is
+ // masked and the read is unmasked, as then the read could cover more elements
+ // than the write (and those extra elements may not equal the splat value).
+ bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
+ if (!couldBeSameSplat)
+ return false;
+ // Check for constant splat (as the source of the write).
+ DenseElementsAttr splatAttr;
+ if (!matchPattern(write.getVector(),
+ m_Constant<DenseElementsAttr>(&splatAttr)) ||
+ !splatAttr.isSplat()) {
+ return false;
+ }
+ // The padding of the read and the constant splat value must be the same.
+ Attribute padAttr;
+ if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+ return false;
+ return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
- !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+ return !defWrite.hasOutOfBoundsDim() &&
+ defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
- defWrite.getPermutationMap() == read.getPermutationMap();
+ defWrite.getPermutationMap() == read.getPermutationMap() &&
+ ((!defWrite.getMask() && !read.getMask()) ||
+ isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..0719c0dd17427 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,128 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Here the read can be eliminated but not the write (due to mismatched masks).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK: vector.transfer_write %[[SPLAT]]
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 1.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Negative test, the masks don't match between the read and write, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Negative test, here the write is masked but the read is unmasked. We can't
+// forward the store (as the write could cover fewer elements than the read).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
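
As a side note on why the reverse case (masked write, unmasked read) is
rejected, a lane-wise sketch with hypothetical values makes the hazard
clear:

  // mask = [1, 1, 0, 0], splat = 0.0, padding = 0.0
  // masked write: lanes 0-1 store 0.0; lanes 2-3 leave memory untouched
  // unmasked read: lanes 2-3 load whatever was already in memory, which
  // need not be 0.0, so the read is not guaranteed to equal the splat.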