[Mlir-commits] [mlir] [mlir][Vector] Tighten up application conditions in TransferReadAfter… (PR #143869)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 12 03:46:30 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/143869

>From 50589c05f71e96b72c98435e86806e7071f3ab49 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 11:27:50 +0200
Subject: [PATCH 1/2] [mlir][Vector] Tighten up application conditions in
 TransferReadAfterWriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 29 +++++++++++++----
 mlir/test/Dialect/Vector/canonicalize.mlir | 37 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..32e9fcf6ed044 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,27 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case; the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // The masked case is too complex atm, bail.
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same, a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>

>From b91e48a97c613938d6963c30bf27ead6a8f7c2c1 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Thu, 12 Jun 2025 12:46:22 +0200
Subject: [PATCH 2/2] Update mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 32e9fcf6ed044..1519f7210be77 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4672,7 +4672,7 @@ struct TransferReadAfterWriteToBroadcast
     if (!defWrite)
       return failure();
     // Bail if we need an alias analysis.
-    if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
       return failure();
     // Bail if we need a bounds analysis.
     if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())



More information about the Mlir-commits mailing list