[Mlir-commits] [mlir] [mlir][vector] Teach `TransferOptimization` to look through trivial aliases (PR #87805)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon May 13 05:17:04 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/87805
>From e3b8fb198d4118804f5ff4b461f5c7f3e9461eac Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 5 Apr 2024 15:55:33 +0000
Subject: [PATCH 1/3] [mlir][vector] Teach `TransferOptimization` to look through
trivial aliases
This allows `TransferOptimization` to eliminate and forward stores made through
trivial aliases of a memref (rather than only through identical memref values).
A trivial alias is (currently) defined as (see the example below):
1. A `memref.cast`
2. A `memref.subview` with a zero offset and unit strides
3. A chain of 1 and 2
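For example (a sketch mirroring the new test case added below; %buffer and
%size stand for arbitrary memref / index values), %buffer, %subview, and %cast
all refer to the same underlying memory, so reads and writes through any of
them can now be matched against each other:

  // %subview is a trivial alias of %buffer: zero offsets and unit strides.
  %subview = memref.subview %buffer[0, 0] [%size, %size] [1, 1]
      : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
  // %cast is a trivial alias of %subview (and, via rule 3, of %buffer).
  %cast = memref.cast %subview
      : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>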
---
.../Transforms/VectorTransferOpTransforms.cpp | 66 ++++++++++++++-----
.../Dialect/Vector/vector-transferop-opt.mlir | 32 ++++++++-
2 files changed, 79 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0ffef6aabccc1..87a03e2f87476 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -88,6 +88,46 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
return false;
}
+/// Walk up the source chain until an operation that changes/defines the view of
+/// memory is found (i.e. skip operations that alias the entire view).
+Value skipFullyAliasingOperations(Value source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
+ subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
+ // A `memref.subview` with an all zero offset, and all unit strides, still
+ // points to the same memory.
+ source = subViewOp.getSource();
+ } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+ // A `memref.cast` still points to the same memory.
+ source = castOp.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
+/// Checks if two (memref) values are are the same, or are statically known to
+/// alias the same region of memory.
+bool isSameViewOrTrivialAlias(Value a, Value b) {
+ return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
+}
+
+/// Walk up the source chain until an op other than a `memref.subview`
+/// or `memref.cast` is found.
+Value skipSubViewsAndCasts(Value source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
+ source = subView.getSource();
+ } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
+ source = cast.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
@@ -104,10 +144,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
- Value source = write.getSource();
- // Skip subview ops.
- while (auto subView = source.getDefiningOp<memref::SubViewOp>())
- source = subView.getSource();
+ Value source = skipSubViewsAndCasts(write.getSource());
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -116,8 +153,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
- users.append(subView->getUsers().begin(), subView->getUsers().end());
+ if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+ users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
@@ -126,7 +163,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check candidate that can override the store.
- if (write.getSource() == nextWrite.getSource() &&
+ if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
checkSameValueWAW(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +228,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
- Value source = read.getSource();
- // Skip subview ops.
- while (auto subView = source.getDefiningOp<memref::SubViewOp>())
- source = subView.getSource();
+ Value source = skipSubViewsAndCasts(read.getSource());
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -203,12 +237,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
- users.append(subView->getUsers().begin(), subView->getUsers().end());
- continue;
- }
- if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
- users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
+ if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+ users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +251,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
cast<VectorTransferOpInterface>(read.getOperation()),
/*testDynamicValueUsingBounds=*/true))
continue;
- if (write.getSource() == read.getSource() &&
+ if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89..e47d26940afa2 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,33 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// Here each read/write is to a different subview, but they all point to the
+// exact same bit of memory (just through casts and subviews with unit strides
+// and zero offsets).
+// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_and_eliminate_stores_through_trivial_aliases(
+ %buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
+) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %cst = arith.constant 0.0 : f32
+ vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
+ %subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
+ %23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
+ %24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
+ scf.yield %24 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+ return
+}
>From 246f8b3db5720604d3d607004fd2b5d36a225935 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 16 Apr 2024 15:48:23 +0000
Subject: [PATCH 2/3] Fix typo
---
.../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 87a03e2f87476..c57c52f6c1324 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -107,8 +107,8 @@ Value skipFullyAliasingOperations(Value source) {
return source;
}
-/// Checks if two (memref) values are are the same, or are statically known to
-/// alias the same region of memory.
+/// Checks if two (memref) values are the same or are statically known to alias
+/// the same region of memory.
bool isSameViewOrTrivialAlias(Value a, Value b) {
return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
}
>From d3aecb2789cc7cf9c89775cf05113cc147fcf863 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 13 May 2024 12:14:40 +0000
Subject: [PATCH 3/3] Move helpers to MemRefUtils
---
.../mlir/Dialect/MemRef/Utils/MemRefUtils.h | 17 ++++++
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 30 ++++++++++
.../Transforms/VectorTransferOpTransforms.cpp | 55 ++++---------------
3 files changed, 58 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 46003ed846869..1c094c1c1328a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -22,6 +22,9 @@ namespace mlir {
class MemRefType;
+/// A value with a memref type.
+using MemrefValue = TypedValue<BaseMemRefType>;
+
namespace memref {
/// Returns true, if the memref type has static shapes and represents a
@@ -93,6 +96,20 @@ computeStridesIRBlock(Location loc, OpBuilder &builder,
return computeSuffixProductIRBlock(loc, builder, sizes);
}
+/// Walk up the source chain until an operation that changes/defines the view of
+/// memory is found (i.e. skip operations that alias the entire view).
+MemrefValue skipFullyAliasingOperations(MemrefValue source);
+
+/// Checks if two (memref) values are the same or are statically known to alias
+/// the same region of memory.
+inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
+ return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
+}
+
+/// Walk up the source chain until an op other than a `memref.subview`
+/// or `memref.cast` is found.
+MemrefValue skipSubViewsAndCasts(MemrefValue source);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index c93e5a9dcd39f..05d5ca2ce12f4 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -178,5 +178,35 @@ computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}
+MemrefValue skipFullyAliasingOperations(MemrefValue source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
+ subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
+ // A `memref.subview` with an all zero offset, and all unit strides, still
+ // points to the same memory.
+ source = cast<MemrefValue>(subViewOp.getSource());
+ } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+ // A `memref.cast` still points to the same memory.
+ source = castOp.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
+MemrefValue skipSubViewsAndCasts(MemrefValue source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
+ source = cast<MemrefValue>(subView.getSource());
+ } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
+ source = cast.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c57c52f6c1324..997b56a1ce142 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -88,46 +89,6 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
return false;
}
-/// Walk up the source chain until an operation that changes/defines the view of
-/// memory is found (i.e. skip operations that alias the entire view).
-Value skipFullyAliasingOperations(Value source) {
- while (auto op = source.getDefiningOp()) {
- if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
- subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
- // A `memref.subview` with an all zero offset, and all unit strides, still
- // points to the same memory.
- source = subViewOp.getSource();
- } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
- // A `memref.cast` still points to the same memory.
- source = castOp.getSource();
- } else {
- return source;
- }
- }
- return source;
-}
-
-/// Checks if two (memref) values are the same or are statically known to alias
-/// the same region of memory.
-bool isSameViewOrTrivialAlias(Value a, Value b) {
- return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
-}
-
-/// Walk up the source chain until an op other than a `memref.subview`
-/// or `memref.cast` is found.
-Value skipSubViewsAndCasts(Value source) {
- while (auto op = source.getDefiningOp()) {
- if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
- source = subView.getSource();
- } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
- source = cast.getSource();
- } else {
- return source;
- }
- }
- return source;
-}
-
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
@@ -144,7 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
- Value source = skipSubViewsAndCasts(write.getSource());
+ Value source =
+ memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -163,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check candidate that can override the store.
- if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
+ if (memref::isSameViewOrTrivialAlias(
+ cast<MemrefValue>(nextWrite.getSource()),
+ cast<MemrefValue>(write.getSource())) &&
checkSameValueWAW(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
@@ -228,7 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
- Value source = skipSubViewsAndCasts(read.getSource());
+ Value source =
+ memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -251,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
cast<VectorTransferOpInterface>(read.getOperation()),
/*testDynamicValueUsingBounds=*/true))
continue;
- if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
+ if (memref::isSameViewOrTrivialAlias(
+ cast<MemrefValue>(read.getSource()),
+ cast<MemrefValue>(write.getSource())) &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;