[Mlir-commits] [mlir] e4de74b - [mlir][Vector] Tighten up application conditions in TransferReadAfter… (#143869)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 12 08:11:10 PDT 2025


Author: Nicolas Vasilache
Date: 2025-06-12T17:11:06+02:00
New Revision: e4de74ba11eadb47cf78afbabffbf2b1a50e7298

URL: https://github.com/llvm/llvm-project/commit/e4de74ba11eadb47cf78afbabffbf2b1a50e7298
DIFF: https://github.com/llvm/llvm-project/commit/e4de74ba11eadb47cf78afbabffbf2b1a50e7298.diff

LOG: [mlir][Vector] Tighten up application conditions in TransferReadAfter… (#143869)

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate
incorrect IR.

In the process, we disable the application of this pattern in the case
where there is no broadcast; this should be handled separately and may
more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this
pattern and arguably could also generate incorrect IR (and was also
untested): the pattern no longer applies to this case.

The last case {yes-broadcast, yes-transpose} continues to apply but
should arguably be removed in the future because creating transposes
as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (e.g.
does your DMA HW allow transpose on the fly or not?) but this is beyond
the scope of this PR.

Co-authored-by: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..2a2357319bd23 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..6691cb52acdc0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // -----
 
 // Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
 //       CHECK:   %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
 //  CHECK-SAME:     {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 //  CHECK-SAME:     {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
 //  CHECK-SAME:       : vector<8x16xf32> to vector<6x4xf32>
 //  CHECK-NEXT:   return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
   -> (vector<6x4xf32>) {
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 //       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
 //       CHECK:   %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 //       CHECK:   %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
 //       CHECK:   return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
 //       CHECK:   vector.transfer_read
 //       CHECK:   vector.transfer_write
 //       CHECK:   return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
 //       CHECK:   vector.transfer_write
 //       CHECK:   vector.transfer_write
 //       CHECK:   %[[V:.*]] = vector.transfer_read
 //       CHECK:   return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+    %arg0 : tensor<?x?xf32>,
+    %arg1 : memref<?x?xf32>,
+    %v0 : vector<4x2xf32>
+  ) -> vector<4x2xf32> 
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+    %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+  -> vector<4x2x6xf32> 
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,7 +1684,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
 //   CHECK-DAG:      %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:      %[[C1:.*]] = arith.constant 1 : index
 //       CHECK:   vector.transfer_write
@@ -1612,7 +1692,7 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 //       CHECK:   vector.transfer_read
 //       CHECK:   %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
 //       CHECK:   return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
 
 // -----
 
-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
 //       CHECK: vector.insert_strided_slice
 //       CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
   -> vector<16xf32> {
   %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
     : vector<2x15xf32> into vector<12x8x16xf32>


        


More information about the Mlir-commits mailing list