[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to forward masked stores (PR #87794)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed May 15 13:58:12 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/87794
>From 7941a704a1dc21bf52f90ab8126a0571377ca0e7 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Apr 2024 10:44:33 +0000
Subject: [PATCH 1/6] [mlir][vector] Teach `TransferOptimization` to forward
masked stores
This only handles one case (that's fairly common in practice*), storing
a masked constant splat, then reloading again with the same mask and a
padding value that matches the splat.
* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 31 +++++++++--
.../Dialect/Vector/vector-transferop-opt.mlir | 52 ++++++++++++++++++-
2 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..8d4eac1324d40 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+ vector::TransferReadOp read) {
+ if (!defWrite.getMask() && !read.getMask())
+ return true; // Success: No masks (values will be the same).
+ // Check for constant splats. These will be the same value if the read is
+ // masked (and padded with the splat value), and the write is unmasked or has
+ // the same mask.
+ bool couldBeSameSplatValue =
+ read.getMask() &&
+ (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+ if (!couldBeSameSplatValue)
+ return false;
+ DenseElementsAttr splatAttr;
+ if (!matchPattern(defWrite.getVector(),
+ m_Constant<DenseElementsAttr>(&splatAttr)) ||
+ !splatAttr.isSplat()) {
+ return false;
+ }
+ Attribute padAttr;
+ if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+ return false;
+ return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
- !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+ return !defWrite.hasOutOfBoundsDim() &&
+ defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
- defWrite.getPermutationMap() == read.getPermutationMap();
+ defWrite.getPermutationMap() == read.getPermutationMap() &&
+ couldBeSameValueWithMasking(defWrite, read);
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..2c8f105cd5c14 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %cst_f32 = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %cst_f32 = arith.constant 1.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ %vscale = vector.vscale
+ vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
>From 6fbdfb9b78b0150fe8703645d57edfa8595538e1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Apr 2024 13:32:48 +0000
Subject: [PATCH 2/6] Fixup - remove unused ops
---
mlir/test/Dialect/Vector/vector-transferop-opt.mlir | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 2c8f105cd5c14..b2fa5c68c17a3 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -374,7 +374,6 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- %vscale = vector.vscale
vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
%0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
@@ -400,7 +399,6 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- %vscale = vector.vscale
vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
%0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
>From 8f5f2c08be7bc53087fcd385e4981a983677d0dc Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Apr 2024 13:51:14 +0000
Subject: [PATCH 3/6] More docs and tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++-
.../Dialect/Vector/vector-transferop-opt.mlir | 72 ++++++++++++++++---
2 files changed, 70 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d4eac1324d40..df9c2c2e5dcf5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,10 +170,16 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
+/// Returns true if the value written by `defWrite` could be the same as the
+/// value read by `read`. Note: True is 'could be' not 'definitely' (as this
+/// simply looks at the masks and the value written). For a definite answer use
+/// `checkSameValueRAW()` -- which calls this function.
static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- if (!defWrite.getMask() && !read.getMask())
- return true; // Success: No masks (values will be the same).
+ if (!defWrite.getMask() && !read.getMask()) {
+ // Success: No masks (values could be the same).
+ return true;
+ }
// Check for constant splats. These will be the same value if the read is
// masked (and padded with the splat value), and the write is unmasked or has
// the same mask.
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index b2fa5c68c17a3..74fca321cd442 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -362,20 +362,47 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
}
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
- %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
- %cst_f32 = arith.constant 0.0 : f32
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
- %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
+
+// Here the read can be eliminated but not the write (due to mismatched masks).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK: vector.transfer_write %[[SPLAT]]
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
%1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
scf.yield %1 : vector<[8]x[8]xf32>
@@ -386,21 +413,21 @@ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf
// Negative test, the padding does not match the constant splat, so we can't
// forward the store.
-// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
-func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
- %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
- %cst_f32 = arith.constant 1.0 : f32
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 1.0 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
- vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
- %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
%x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
%1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
scf.yield %1 : vector<[8]x[8]xf32>
@@ -408,3 +435,28 @@ func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : mem
vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
return
}
+
+// Negative test, the masks don't match between the read and write, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 1.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
>From 30267edd342cfb7d36424b52c262abd1804d9238 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 15 May 2024 14:21:09 +0000
Subject: [PATCH 4/6] Fixups
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index df9c2c2e5dcf5..80b8c2901cd04 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -176,17 +176,17 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
/// `checkSameValueRAW()` -- which calls this function.
static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- if (!defWrite.getMask() && !read.getMask()) {
+ auto readMask = read.getMask();
+ auto writeMask = defWrite.getMask();
+ if (!writeMask && !readMask) {
// Success: No masks (values could be the same).
return true;
}
// Check for constant splats. These will be the same value if the read is
// masked (and padded with the splat value), and the write is unmasked or has
// the same mask.
- bool couldBeSameSplatValue =
- read.getMask() &&
- (!defWrite.getMask() || defWrite.getMask() == read.getMask());
- if (!couldBeSameSplatValue)
+ bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
+ if (!couldBeSameSplat)
return false;
DenseElementsAttr splatAttr;
if (!matchPattern(defWrite.getVector(),
>From 219b35fdc943829c769f019057ea12b15294c864 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 15 May 2024 15:22:09 +0000
Subject: [PATCH 5/6] More docs and tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++-
.../Dialect/Vector/vector-transferop-opt.mlir | 27 ++++++++++++++++++-
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 80b8c2901cd04..03352a691642e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -184,7 +184,9 @@ static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
}
// Check for constant splats. These will be the same value if the read is
// masked (and padded with the splat value), and the write is unmasked or has
- // the same mask.
+ // the same mask. Note this does not allow the case where the write is masked
+ // and the read is unmasked, as then the read could be of more elements than
+ // the write (which may not be the same value).
bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
if (!couldBeSameSplat)
return false;
@@ -194,6 +196,7 @@ static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
!splatAttr.isSplat()) {
return false;
}
+ // The padding of the read and the constant splat value must be the same.
Attribute padAttr;
if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
return false;
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 74fca321cd442..0719c0dd17427 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -447,7 +447,7 @@ func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : m
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
%zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
- %read_padding = arith.constant 1.0 : f32
+ %read_padding = arith.constant 0.0 : f32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
@@ -460,3 +460,28 @@ func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : m
vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
return
}
+
+// Negative test, here the write is masked but the read is unmasked. We can't
+// forward the store (as the write could be of fewer elements than the read).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+ %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+ %read_padding = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : index
+ vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+ %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+ %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+ scf.yield %1 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ return
+}
>From 7bca6d12f9cfdddbf9cf0213894ec754cc75779f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 15 May 2024 20:57:18 +0000
Subject: [PATCH 6/6] Naming things
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 03352a691642e..176e37a6fa6d4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,14 +170,15 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
shapedType.getContext());
}
-/// Returns true if the value written by `defWrite` could be the same as the
-/// value read by `read`. Note: True is 'could be' not 'definitely' (as this
-/// simply looks at the masks and the value written). For a definite answer use
-/// `checkSameValueRAW()` -- which calls this function.
-static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
- vector::TransferReadOp read) {
+/// Returns true if both `write` and `read` have no masks, or the write is of a
+/// constant splat and the masks + padding means the read will always be the
+/// same constant splat. Note: This assumes other properties of both `read` and
+/// `write` have already been checked (e.g. indices, source, etc). For a more
+/// general check use `checkSameValueRAW`.
+static bool hasNoMasksOrConsistentSplatAndPadding(vector::TransferWriteOp write,
+ vector::TransferReadOp read) {
auto readMask = read.getMask();
- auto writeMask = defWrite.getMask();
+ auto writeMask = write.getMask();
if (!writeMask && !readMask) {
// Success: No masks (values could be the same).
return true;
@@ -191,7 +192,7 @@ static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
if (!couldBeSameSplat)
return false;
DenseElementsAttr splatAttr;
- if (!matchPattern(defWrite.getVector(),
+ if (!matchPattern(write.getVector(),
m_Constant<DenseElementsAttr>(&splatAttr)) ||
!splatAttr.isSplat()) {
return false;
@@ -209,7 +210,7 @@ bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
defWrite.getPermutationMap() == read.getPermutationMap() &&
- couldBeSameValueWithMasking(defWrite, read);
+ hasNoMasksOrConsistentSplatAndPadding(defWrite, read);
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
More information about the Mlir-commits
mailing list