[Mlir-commits] [mlir] [mlir][Vector] Tighten up application conditions in TransferReadAfterWriteToBroadcast (PR #143869)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 12 03:46:30 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/143869
>From 50589c05f71e96b72c98435e86806e7071f3ab49 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 11:27:50 +0200
Subject: [PATCH 1/2] [mlir][Vector] Tighten up application conditions in
TransferReadAfterWriteToBroadcast
The pattern would previously apply in spurious cases and generate incorrect IR.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 29 +++++++++++++----
mlir/test/Dialect/Vector/canonicalize.mlir | 37 ++++++++++++++++++++++
2 files changed, 59 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..32e9fcf6ed044 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- if (readOp.hasOutOfBoundsDim() ||
- !llvm::isa<RankedTensorType>(readOp.getShapedType()))
- return failure();
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
+ // Bail if we need an alias analysis.
+ if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+ return failure();
+ // Bail if we need a bounds analysis.
+ if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+ return failure();
// TODO: If the written transfer chunk is a superset of the read transfer
// chunk we could do an extract_strided_slice.
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,27 @@ struct TransferReadAfterWriteToBroadcast
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
return failure();
- if (readOp.getIndices() != defWrite.getIndices() ||
- readOp.getMask() != defWrite.getMask())
+ // This pattern should only catch the broadcast case; the non-broadcast case
+ // should be handled separately to keep application conditions clean and
+ // separate.
+ AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+ AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+ bool bcast = !readMap.getBroadcastDims().empty() ||
+ !writeMap.getBroadcastDims().empty();
+ if (!bcast)
+ return failure();
+ // At this point, we know we have a bcast.
+ // The masked case is too complex for now, bail.
+ if (readOp.getMask() || defWrite.getMask())
+ return failure();
+ // If indices are not the same, a shift may be required; bail.
+ if (readOp.getIndices() != defWrite.getIndices())
return failure();
+
Value vec = defWrite.getVector();
// TODO: loop through the chain of transfer_write if we can prove that they
// don't overlap with the transfer_read. This requires improving
// `isDisjointTransferIndices` helper.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
AffineMap map = readMap.compose(writeMap);
if (map.getNumResults() == 0)
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
// -----
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+ tensor<?x?xf32>, vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+ tensor<?x?xf32>, vector<4x2x6xf32>
+ return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
>From b91e48a97c613938d6963c30bf27ead6a8f7c2c1 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 12:46:22 +0200
Subject: [PATCH 2/2] Update mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 32e9fcf6ed044..1519f7210be77 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4672,7 +4672,7 @@ struct TransferReadAfterWriteToBroadcast
if (!defWrite)
return failure();
// Bail if we need an alias analysis.
- if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+ if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
return failure();
// Bail if we need a bounds analysis.
if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
More information about the Mlir-commits
mailing list