[Mlir-commits] [mlir] [mlir][Vector] Tighten up application conditions in TransferReadAfter… (PR #143869)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 12 05:58:17 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/143869

From 0210067bd7a4bd47fa18cb8107f29c23dbd6a523 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 11:27:50 +0200
Subject: [PATCH] [mlir][Vector] Tighten up application conditions in
 TransferReadAfterWriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.

In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore.

The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future because creating transposes as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not) but this is beyond the scope of this PR.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 30 ++++++++++++++----
 mlir/test/Dialect/Vector/canonicalize.mlir | 37 ++++++++++++++++++++++
 2 files changed, 60 insertions(+), 7 deletions(-)
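
The application conditions described in the commit message can be sketched as a small standalone C++ model. This is only an illustration under stated assumptions: `ToyTransfer`, `hasBroadcastDim` and `mayForwardAsBroadcast` are hypothetical names that do not exist in MLIR, the permutation map is reduced to a list of optional source dimensions, and the transfer-chunk and unused-dims checks from the real pattern are omitted.

```cpp
#include <optional>
#include <vector>

// Toy stand-in for a transfer_read / transfer_write op. NOT the MLIR API;
// each field only mirrors one query used by the pattern.
struct ToyTransfer {
  bool pureTensor;           // mirrors hasPureTensorSemantics()
  bool outOfBounds;          // mirrors hasOutOfBoundsDim()
  bool masked;               // mirrors a non-null getMask()
  std::vector<int> indices;  // mirrors getIndices(), as plain constants
  // One entry per vector dimension: the source dimension it maps to, or
  // std::nullopt for a broadcast dimension (a constant 0 map result).
  std::vector<std::optional<int>> permutationMap;
};

// True if at least one vector dimension is a broadcast dimension.
bool hasBroadcastDim(const ToyTransfer &t) {
  for (const auto &dim : t.permutationMap)
    if (!dim)
      return true;
  return false;
}

// Mirrors the order of the bail-outs in this patch (assumption: the
// omitted chunk and unused-dims checks would sit before the broadcast test).
bool mayForwardAsBroadcast(const ToyTransfer &read, const ToyTransfer &write) {
  // Bail if we would need an alias analysis.
  if (!read.pureTensor || !write.pureTensor)
    return false;
  // Bail if we would need a bounds analysis.
  if (read.outOfBounds || write.outOfBounds)
    return false;
  // Only the broadcast case is handled here.
  if (!hasBroadcastDim(read) && !hasBroadcastDim(write))
    return false;
  // Masked transfers are not handled.
  if (read.masked || write.masked)
    return false;
  // Different indices may require a shift.
  return read.indices == write.indices;
}
```

In this model, a read and a write that both use an identity map are rejected because neither has a broadcast dimension, which corresponds to the new `store_to_load_tensor_no_actual_broadcast` test. A write that may be out of bounds is rejected even when the read broadcasts, which corresponds to the `..._out_of_bounds_should_not_canonicalize` test.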

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..993ad829bc097 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>


