[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon May 13 06:19:09 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/87794
>From ed9183fcc039035c182a3c2771cb209cfd16de5a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Apr 2024 10:44:33 +0000
Subject: [PATCH 1/3] [mlir][vector] Teach `TransferOptimization` to forward
masked stores
This only handles one case (that's fairly common in practice*): storing
a masked constant splat, then reloading it with the same mask and a
padding value that matches the splat.
* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 31 +++++++++--
.../Dialect/Vector/vector-transferop-opt.mlir | 52 ++++++++++++++++++-
2 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..8d4eac1324d40 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+ vector::TransferReadOp read) {
+ if (!defWrite.getMask() && !read.getMask())
+ return true; // Success: No masks (values will be the same).
+ // Check for constant splats. These will be the same value if the read is
+ // masked (and padded with the splat value), and the write is unmasked or has
+ // the same mask.
+ bool couldBeSameSplatValue =
+ read.getMask() &&
+ (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+ if (!couldBeSameSplatValue)
+ return false;
+ DenseElementsAttr splatAttr;
+ if (!matchPattern(defWrite.getVector(),
+ m_Constant<DenseElementsAttr>(&splatAttr)) ||
+ !splatAttr.isSplat()) {
+ return false;
+ }
+ Attribute padAttr;
+ if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+ return false;
+ return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
- !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+ return !defWrite.hasOutOfBoundsDim() &&
+ defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
- defWrite.getPermutationMap() == read.getPermutationMap();
+ defWrite.getPermutationMap() == read.getPermutationMap() &&
+ couldBeSameValueWithMasking(defWrite, read);
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..2c8f105cd5c14 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %cst_f32 = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %cst_f32 = arith.constant 1.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
>From 89931de90bafd501f8b194facaba88d7e83919e1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Apr 2024 13:32:48 +0000
Subject: [PATCH 2/3] Fixup - remove unused ops
---
mlir/test/Dialect/Vector/vector-transferop-opt.mlir | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 2c8f105cd5c14..b2fa5c68c17a3 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -374,7 +374,6 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- %vscale = vector.vscale
vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
%0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
@@ -400,7 +399,6 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- %vscale = vector.vscale
vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
%0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
>From 884e977b51f8bdb05860ed5095f2c593573ab53d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Apr 2024 13:51:14 +0000
Subject: [PATCH 3/3] More docs and tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++-
.../Dialect/Vector/vector-transferop-opt.mlir | 72 ++++++++++++++++---
2 files changed, 70 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d4eac1324d40..df9c2c2e5dcf5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,10 +170,16 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+/// Returns true if the value written by `defWrite` could be the same as the
+/// value read by `read`. Note: True is 'could be' not 'definitely' (as this
+/// simply looks at the masks and the value written). For a definite answer use
+/// `checkSameValueRAW()` -- which calls this function.
static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- if (!defWrite.getMask() && !read.getMask())
- return true; // Success: No masks (values will be the same).
+ if (!defWrite.getMask() && !read.getMask()) {
+ // Success: No masks (values could be the same).
+ return true;
+ }
// Check for constant splats. These will be the same value if the read is
// masked (and padded with the splat value), and the write is unmasked or has
// the same mask.
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index b2fa5c68c17a3..74fca321cd442 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -362,20 +362,47 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
}
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
- %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
- %cst_f32 = arith.constant 0.0 : f32
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
- %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Here the read can be eliminated but not the write (due to mismatched masks).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK: vector.transfer_write %[[SPLAT]]
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
%1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
scf.yield %1 : vector<[8]x[8]xf32>
@@ -386,21 +413,21 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
// Negative test, the padding does not match the constant splat, so we can't
// forward the store.
-// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
-func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
- %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
- %cst_f32 = arith.constant 1.0 : f32
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 1.0 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
- %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
%1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
scf.yield %1 : vector<[8]x[8]xf32>
@@ -408,3 +435,28 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
return
}
+
+// Negative test, the masks don't match between the read and write, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 1.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
More information about the Mlir-commits
mailing list