[Mlir-commits] [mlir] [mlir][Vector] Tighten up application conditions in TransferReadAfter… (PR #143869)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 12 07:48:01 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/143869
>From 0dd282f5f618d51bc465d482a3f478c170793247 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 11:27:50 +0200
Subject: [PATCH] [mlir][Vector] Tighten up application conditions in
TransferReadAfterWriteToBroadcast
The pattern would previously apply in spurious cases and generate incorrect IR.
In the process, this pattern no longer applies when there is no broadcast; that case should be handled by a separate pattern, which may also more easily support masking.
The case {no-broadcast, yes-transpose} was previously matched by this pattern and could arguably also generate incorrect IR (it was also untested); the pattern no longer applies to this case.
The last case {yes-broadcast, yes-transpose} still applies, but should arguably be removed in the future because creating transposes as part of canonicalization is risky.
There are other patterns that move permutation logic:
- either into the transfer, or
- out of the transfer.
Ideally, this would be target-dependent rather than a canonicalization (i.e., it would depend on whether the DMA hardware supports transposing on the fly), but this is beyond the scope of this PR.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 30 ++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 108 ++++++++++++++++++---
2 files changed, 117 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..2a2357319bd23 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- if (readOp.hasOutOfBoundsDim() ||
- !llvm::isa<RankedTensorType>(readOp.getShapedType()))
- return failure();
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
+ // Bail if we need an alias analysis.
+ if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+ return failure();
+ // Bail if we need a bounds analysis.
+ if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+ return failure();
// TODO: If the written transfer chunk is a superset of the read transfer
// chunk we could do an extract_strided_slice.
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
return failure();
- if (readOp.getIndices() != defWrite.getIndices() ||
- readOp.getMask() != defWrite.getMask())
+ // This pattern only handles the broadcast case; the non-broadcast case
+ // should be handled by a separate pattern to keep application conditions
+ // simple and independent.
+ AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+ AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+ bool bcast = !readMap.getBroadcastDims().empty() ||
+ !writeMap.getBroadcastDims().empty();
+ if (!bcast)
+ return failure();
+ // At this point, we know we have a bcast.
+ // Bail in the masked case (too complex for now: supporting it would
+ // require properly accounting for padding).
+ if (readOp.getMask() || defWrite.getMask())
+ return failure();
+ // If the indices differ, a shift may be required; bail.
+ if (readOp.getIndices() != defWrite.getIndices())
return failure();
+
Value vec = defWrite.getVector();
// TODO: loop through the chain of transfer_write if we can prove that they
// don't overlap with the transfer_read. This requires improving
// `isDisjointTransferIndices` helper.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
AffineMap map = readMap.compose(writeMap);
if (map.getNumResults() == 0)
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..6691cb52acdc0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// -----
// Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
-> (vector<6x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
// -----
-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
// -----
-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
// -----
-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: %[[V:.*]] = vector.transfer_read
// CHECK: return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
// -----
+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+ %arg0 : tensor<?x?xf32>,
+ %arg1 : memref<?x?xf32>,
+ %v0 : vector<4x2xf32>
+ ) -> vector<4x2xf32>
+{
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+ vector<4x2xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+ tensor<?x?xf32>, vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+ tensor<?x?xf32>, vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+ tensor<?x?xf32>, vector<4x2x6xf32>
+ return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+ %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+ -> vector<4x2x6xf32>
+{
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+ tensor<?x?xf32>, vector<4x2x6xf32>
+ return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,7 +1684,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
// -----
-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: vector.transfer_write
@@ -1612,7 +1692,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
// CHECK: vector.transfer_read
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
// -----
-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
// CHECK: vector.insert_strided_slice
// CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
-> vector<16xf32> {
%0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
: vector<2x15xf32> into vector<12x8x16xf32>
More information about the Mlir-commits
mailing list