[Mlir-commits] [mlir] [mlir][vector] Add FoldTransferReadAfterTransferWrite. (PR #196608)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 12 06:50:48 PDT 2026


llvmorg-github-actions[bot] wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

<details>
<summary>Changes</summary>

Adds a canonicalization pattern for folding:

```
transfer_read(transfer_write(valToStore, original, wMask), rMask, rPad)
----------------------------------------------------------------------
select(rMask, select(wMask, valToStore, original), broadcast(rPad))
```

when `not(readOp.hasOutOfBoundsDim() && writeOp.hasOutOfBoundsDim())`, i.e. when at most one of the two ops has an out-of-bounds dimension.
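As a sanity check of the rewrite rule above, here is a minimal scalar C++ model. It assumes a 1-D tensor, zero offsets, and fully in-bounds accesses, so out-of-bounds behaviour is not modelled. The helpers `transferWrite`, `transferRead`, `selectLanes`, `foldedRead`, and `foldHoldsForAllMasks` are hypothetical stand-ins for the vector/arith ops, not MLIR APIs. The last one compares both sides of the rule for every pair of masks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar model: 1-D tensor, both transfers start at index 0 and
// are fully in bounds, so vector length == tensor length.
using Vec = std::vector<float>;
using Mask = std::vector<bool>;

// Tensor-semantics transfer_write: enabled lanes take valToStore, disabled
// lanes keep the original element. Returns the new tensor value.
Vec transferWrite(const Vec &valToStore, const Vec &original,
                  const Mask &wMask) {
  Vec written = original;
  for (std::size_t i = 0; i < valToStore.size(); ++i)
    if (wMask[i])
      written[i] = valToStore[i];
  return written;
}

// transfer_read: disabled lanes produce the padding value rPad.
Vec transferRead(const Vec &tensor, const Mask &rMask, float rPad) {
  Vec result(rMask.size());
  for (std::size_t i = 0; i < rMask.size(); ++i)
    result[i] = rMask[i] ? tensor[i] : rPad;
  return result;
}

// Element-wise arith.select.
Vec selectLanes(const Mask &cond, const Vec &a, const Vec &b) {
  Vec result(cond.size());
  for (std::size_t i = 0; i < cond.size(); ++i)
    result[i] = cond[i] ? a[i] : b[i];
  return result;
}

// Folded form:
// select(rMask, select(wMask, valToStore, original), broadcast(rPad)).
Vec foldedRead(const Vec &valToStore, const Vec &original, const Mask &wMask,
               const Mask &rMask, float rPad) {
  Vec padVec(rMask.size(), rPad); // broadcast(rPad)
  return selectLanes(rMask, selectLanes(wMask, valToStore, original), padVec);
}

// Mask of n lanes taken from the low bits of `bits`.
Mask maskFromBits(unsigned bits, std::size_t n) {
  Mask m(n);
  for (std::size_t i = 0; i < n; ++i)
    m[i] = ((bits >> i) & 1u) != 0;
  return m;
}

// Compares both sides for every (wMask, rMask) pair of n lanes (n small).
bool foldHoldsForAllMasks(std::size_t n) {
  Vec val(n), original(n);
  for (std::size_t i = 0; i < n; ++i) {
    val[i] = static_cast<float>(i + 1);
    original[i] = static_cast<float>(10 * (i + 1));
  }
  const float rPad = -1.0f;
  for (unsigned w = 0; w < (1u << n); ++w)
    for (unsigned r = 0; r < (1u << n); ++r) {
      Mask wMask = maskFromBits(w, n);
      Mask rMask = maskFromBits(r, n);
      if (transferRead(transferWrite(val, original, wMask), rMask, rPad) !=
          foldedRead(val, original, wMask, rMask, rPad))
        return false;
    }
  return true;
}
```

For example, with `val = {1, 2, 3, 4}`, `original = {10, 20, 30, 40}`, `wMask = {1, 0, 1, 0}`, `rMask = {1, 1, 0, 0}`, and `rPad = 0`, both sides give `{1, 20, 0, 0}`.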

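When only one of the two ops has out-of-bounds (OOB) dims, undefined behaviour (UB) makes the fold valid. Here `w_ib` and `r_ib` are the write's and the read's `in_bounds` flags for the dim in question:

- Case 1.1: `w_ib = false`, `r_ib = true`, position actually in bounds: we write `val` and read back `val`, so the read-after-write (RAW) folds to `val`.
- Case 1.2: `w_ib = false`, `r_ib = true`, position NOT in bounds: the write is skipped, and the read claims to be in bounds although it is not, which is UB. Therefore we can fold to `val`.
- Case 2.1: `w_ib = true`, `r_ib = false`, position actually in bounds: we write `val` and read back `val`, so the RAW folds to `val`.
- Case 2.2: `w_ib = true`, `r_ib = false`, position NOT in bounds: the write claims to be in bounds although it is not, which is UB. Therefore we can fold.

If both ops have OOB dims and the position is actually out of bounds, there is no UB to rely on: the write is skipped and the read returns the padding value, so folding to `val` would be incorrect.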
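The case analysis can be checked with a similarly hypothetical model. It assumes unmasked 1-D transfers that both start at index 0 on a tensor with `dim` elements. `std::nullopt` stands for UB, and folding to `val` counts as valid when the original program either has UB or already yields `val`. The names `modelWrite`, `modelRead`, and `foldToValIsValid` are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Vec = std::vector<float>;

// Hypothetical model: unmasked 1-D transfers starting at index 0 on a tensor
// of `tensor.size()` elements. std::nullopt stands for undefined behaviour.

// transfer_write: lanes past the end are dropped when inBounds == false and
// are UB when inBounds == true.
std::optional<Vec> modelWrite(const Vec &val, Vec tensor, bool inBounds) {
  for (std::size_t i = 0; i < val.size(); ++i) {
    if (i < tensor.size())
      tensor[i] = val[i];
    else if (inBounds)
      return std::nullopt; // claims in-bounds, but is not: UB
    // else: out-of-bounds lane of an in_bounds=false write is dropped
  }
  return tensor;
}

// transfer_read: lanes past the end read `pad` when inBounds == false and
// are UB when inBounds == true.
std::optional<Vec> modelRead(const Vec &tensor, std::size_t lanes, float pad,
                             bool inBounds) {
  Vec result(lanes);
  for (std::size_t i = 0; i < lanes; ++i) {
    if (i < tensor.size())
      result[i] = tensor[i];
    else if (inBounds)
      return std::nullopt; // claims in-bounds, but is not: UB
    else
      result[i] = pad;
  }
  return result;
}

// Folding read(write(val)) to `val` is valid if the original program has UB
// (any result is allowed) or already produces `val`.
bool foldToValIsValid(const Vec &val, std::size_t dim, bool wInBounds,
                      bool rInBounds) {
  const Vec original(dim, 0.0f);
  std::optional<Vec> written = modelWrite(val, original, wInBounds);
  if (!written)
    return true;
  std::optional<Vec> read =
      modelRead(*written, val.size(), /*pad=*/-1.0f, rInBounds);
  if (!read)
    return true;
  return *read == val;
}
```

With `dim` smaller than the vector length, only the combination where both flags are `false` rejects the fold, because the dropped lanes are read back as padding.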

Parts of this patch were written with the assistance of Claude Opus 4.6.

---

Patch is 20.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/196608.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+94-1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+331-18) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 51be1e4431e70..4d31e86f8dcf6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/UB/IR/UBMatchers.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -5625,11 +5626,103 @@ struct TransferReadAfterWriteToBroadcast
     return success();
   }
 };
+
+/// Folds a transfer_read that reads from the result of a transfer_write on
+/// the same region (Read-After-Write) into arithmetic on the written value,
+/// the original tensor, the masks, and the read's padding.
+///
+/// The general semantics are:
+///
+///   written_tensor[i] = wMask[i] ? valToStore[i] : original[i]
+///   result[i]         = rMask[i] ? written_tensor[i] : rPad
+///
+/// Which gives:
+///   result = select(rMask, select(wMask, valToStore, original),
+///   broadcast(rPad))
+///
+/// Special cases avoid emitting unnecessary IR:
+///   - No wMask (unmasked write): wMask is implicitly all-true, inner select
+///     collapses to valToStore.
+///   - No rMask (unmasked read): rMask is implicitly all-true, outer select
+///     collapses away.
+///   - wMask == rMask: the original tensor is never needed (anywhere rMask is
+///     true, wMask is also true), so the inner select collapses to valToStore.
+///
+/// After bufferization, this generally removes the need for materializing the
+/// write to memory.
+struct FoldTransferReadAfterTransferWrite
+    : public OpRewritePattern<TransferReadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    if (!readOp.hasPureTensorSemantics())
+      return failure();
+
+    if (readOp->getParentOfType<MaskOp>())
+      return failure();
+
+    auto writeOp =
+        dyn_cast_if_present<TransferWriteOp>(readOp.getBase().getDefiningOp());
+    if (!writeOp || !writeOp.hasPureTensorSemantics())
+      return failure();
+
+    Value valToStore = writeOp.getValueToStore();
+    if (valToStore.getType() != readOp.getType())
+      return failure();
+
+    if ((llvm::any_of(readOp.getIndices(),
+                      [](Value v) { return !isZeroInteger(v); }) ||
+         llvm::any_of(writeOp.getIndices(),
+                      [](Value v) { return !isZeroInteger(v); })) &&
+        (readOp.getIndices() != writeOp.getIndices()))
+      return failure();
+
+    if (!readOp.getPermutationMap().isMinorIdentity() ||
+        !writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    // We cannot fold when both of them are out of bounds.
+    // If one of them is in bounds but the other one isn't, then
+    // we can take advantage of undefined behaviour to fold.
+    if (readOp.hasOutOfBoundsDim() && writeOp.hasOutOfBoundsDim())
+      return failure();
+
+    TypedValue<VectorType> wMask = writeOp.getMask();
+    TypedValue<VectorType> rMask = readOp.getMask();
+
+    // Build the inner value: select(wMask, valToStore, original).
+    // When wMask is absent (unmasked write) or wMask == rMask (original is
+    // never accessed), this simplifies to just valToStore.
+    Value inner = valToStore;
+    bool needsOriginal = wMask && wMask != rMask;
+    if (needsOriginal) {
+      Value originalRead = TransferReadOp::create(
+          rewriter, readOp.getLoc(), readOp.getType(), writeOp.getBase(),
+          readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
+          /*mask=*/Value(), readOp.getInBoundsAttr());
+      inner = arith::SelectOp::create(rewriter, readOp.getLoc(), wMask,
+                                      valToStore, originalRead);
+    }
+
+    if (!rMask) {
+      rewriter.replaceOp(readOp, inner);
+      return success();
+    }
+
+    Value rPad = readOp.getPadding();
+    Value padVal = BroadcastOp::create(rewriter, rPad.getLoc(),
+                                       valToStore.getType(), rPad);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(readOp, rMask, inner, padVal);
+    return success();
+  }
+};
 } // namespace
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<TransferReadAfterWriteToBroadcast>(context);
+  results.add<TransferReadAfterWriteToBroadcast,
+              FoldTransferReadAfterTransferWrite>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6aa92ab79a0dd..3a65a6a70928e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1991,24 +1991,6 @@ func.func @negative_store_to_load_tensor_memref(
 
 // -----
 
-// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
-//   CHECK-NOT:   vector.broadcast
-//   CHECK-NOT:   vector.transpose
-//       CHECK:   vector.transfer_write
-//       CHECK:   vector.transfer_read
-func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
-  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
-  %c0 = arith.constant 0 : index
-  %cf0 = arith.constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
-    vector<4x2xf32>, tensor<?x?xf32>
-  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<?x?xf32>, vector<4x2xf32>
-  return %0 : vector<4x2xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
 //   CHECK-NOT:   vector.broadcast
 //   CHECK-NOT:   vector.transpose
@@ -2106,6 +2088,337 @@ func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
 
 // -----
 
+// Both write and read are masked with the same mask: the original tensor is
+// never needed, so the inner select collapses. Result is
+// select(mask, val, broadcast(pad)).
+// CHECK-LABEL: func @fold_transfer_raw_both_masked
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[CST_1:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-DAG:     %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[CST_1]], %[[CST_0]]
+// CHECK:         return %[[SEL]]
+func.func @fold_transfer_raw_both_masked(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Masked write, unmasked read: replace with select(wMask, val, read(original)).
+// CHECK-LABEL: func @fold_transfer_raw_masked_write_unmasked_read
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[T]]{{.*}}, %[[CST]]
+// CHECK-SAME:      {in_bounds = [true]}
+// CHECK-SAME:      : tensor<128xf16>, vector<128xf16>
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[READ]]
+// CHECK:         return %[[SEL]]
+func.func @fold_transfer_raw_masked_write_unmasked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Both unmasked: the read is directly replaced by the written value.
+// CHECK-LABEL: func @fold_transfer_raw_both_unmasked
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         return %[[VAL]]
+func.func @fold_transfer_raw_both_unmasked(%t: tensor<128xf16>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Unmasked write, masked read: result is select(rMask, val, broadcast(pad)).
+// CHECK-LABEL: func @fold_transfer_raw_unmasked_write_masked_read
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-DAG:     %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[RES:.+]] = arith.select %[[MASK]], %[[VAL]], %[[PAD]]
+// CHECK-NEXT:    return %[[RES]]
+func.func @fold_transfer_raw_unmasked_write_masked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Negative test: memref semantics — pattern must not fire.
+// CHECK-LABEL: func @negative_fold_transfer_raw_memref
+// CHECK:         vector.transfer_write
+// CHECK:         vector.transfer_read
+func.func @negative_fold_transfer_raw_memref(%m: memref<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  vector.transfer_write %val, %m[%c0], %mask {in_bounds = [true]}
+     : vector<128xf16>, memref<128xf16>
+  %r = vector.transfer_read %m[%c0], %cst, %mask {in_bounds = [true]}
+     : memref<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Negative test: type mismatch between written and read vectors.
+// CHECK-LABEL: func @negative_fold_transfer_raw_type_mismatch
+// CHECK:         vector.transfer_write
+// CHECK:         vector.transfer_read
+func.func @negative_fold_transfer_raw_type_mismatch(%t: tensor<128xf16>) -> vector<64xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<64xf16>
+  return %r : vector<64xf16>
+}
+
+// -----
+
+// Negative test: different non-zero indices between write and read.
+// CHECK-LABEL: func @negative_fold_transfer_raw_different_indices
+// CHECK:         vector.transfer_write
+// CHECK:         vector.transfer_read
+func.func @negative_fold_transfer_raw_different_indices(
+    %t: tensor<256xf16>, %i: index, %j: index) -> vector<128xf16> {
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%i] {in_bounds = [true]}
+     : vector<128xf16>, tensor<256xf16>
+  %r = vector.transfer_read %w[%j], %cst {in_bounds = [true]}
+     : tensor<256xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// -----
+
+// Write has OOB dim, read claims in-bounds, same mask: fold is valid because
+// the read's in_bounds=true makes an actual OOB access UB.
+// CHECK-LABEL: func @fold_transfer_raw_oob_write_same_mask
+// CHECK-SAME:    %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<1x32x16xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[PAD]]
+// CHECK:         return %[[SEL]]
+func.func @fold_transfer_raw_oob_write_same_mask(
+    %val: vector<1x32x16xf16>, %sz: index,
+    %mask: vector<1x32x16xi1>) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0], %mask
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad, %mask
+     {in_bounds = [true, true, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Negative: both write and read have OOB dims, no masks — fold must NOT fire.
+// CHECK-LABEL: func @negative_fold_transfer_raw_oob_both_no_masks
+// CHECK:         vector.transfer_write
+// CHECK:         vector.transfer_read
+func.func @negative_fold_transfer_raw_oob_both_no_masks(
+    %val: vector<1x32x16xf16>, %sz: index) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0]
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad
+     {in_bounds = [true, false, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Negative: both write and read have OOB dims with same mask — fold must NOT
+// fire.
+// CHECK-LABEL: func @negative_fold_transfer_raw_oob_both_same_mask
+// CHECK:         vector.transfer_write
+// CHECK:         vector.transfer_read
+func.func @negative_fold_transfer_raw_oob_both_same_mask(
+    %val: vector<1x32x16xf16>, %sz: index,
+    %mask: vector<1x32x16xi1>) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0], %mask
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad, %mask
+     {in_bounds = [true, false, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Only read has OOB dim, write claims in-bounds, same mask: fold is valid
+// because the write's in_bounds=true makes an actual OOB access UB.
+// CHECK-LABEL: func @fold_transfer_raw_oob_read_only
+// CHECK-SAME:    %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<1x32x16xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[PAD]]
+// CHECK:         return %[[SEL]]
+func.func @fold_transfer_raw_oob_read_only(
+    %val: vector<1x32x16xf16>, %sz: index,
+    %mask: vector<1x32x16xi1>) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0], %mask
+     {in_bounds = [true, true, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad, %mask
+     {in_bounds = [true, false, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Write has OOB dim, read claims in-bounds, no masks: fold is valid.
+// CHECK-LABEL: func @fold_transfer_raw_oob_write_no_masks
+// CHECK-SAME:    %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         return %[[VAL]]
+func.func @fold_transfer_raw_oob_write_no_masks(
+    %val: vector<1x32x16xf16>, %sz: index) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0]
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad
+     {in_bounds = [true, true, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Write has OOB dim, read claims in-bounds, different masks: fold is valid.
+// The inner select reads from the original tensor (tensor.empty), producing
+// select(wMask, val, read(tensor.empty)). The outer select then applies rMask.
+// CHECK-LABEL: func @fold_transfer_raw_oob_write_different_masks
+// CHECK-SAME:    %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[SZ:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[WMASK:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[RMASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[PAD_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<1x32x16xf16>
+// CHECK-DAG:     %[[PAD:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK:         %[[EMPTY:.*]] = tensor.empty(%[[SZ]])
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[EMPTY]]{{.*}}, %[[PAD]]
+// CHECK:         %[[INNER:.*]] = arith.select %[[WMASK]], %[[VAL]], %[[READ]]
+// CHECK:         %[[OUTER:.*]] = arith.select %[[RMASK]], %[[INNER]], %[[PAD_VEC]]
+// CHECK:         return %[[OUTER]]
+func.func @fold_transfer_raw_oob_write_different_masks(
+    %val: vector<1x32x16xf16>, %sz: index,
+    %wmask: vector<1x32x16xi1>,
+    %rmask: vector<1x32x16xi1>) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0], %wmask
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad, %rmask
+     {in_bounds = [true, true, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Write has OOB dim, only write is masked, read claims in-bounds: fold is
+// valid. The inner select reads from the original tensor (tensor.empty),
+// producing select(wMask, val, read(tensor.empty)). No rMask, so the result
+// is the inner select.
+// CHECK-LABEL: func @fold_transfer_raw_oob_write_only_write_masked
+// CHECK-SAME:    %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[SZ:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK:         %[[EMPTY:.*]] = tensor.empty(%[[SZ]])
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[EMPTY]]{{.*}}, %[[PAD]]
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[READ]]
+// CHECK:         return %[[SEL]]
+func.func @fold_transfer_raw_oob_write_only_write_masked(
+    %val: vector<1x32x16xf16>, %sz: index,
+    %mask: vector<1x32x16xi1>) -> vector<1x32x16xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<1x?x16xf16>
+  %w = vector.transfer_write %val, %e[%c0, %c0, %c0], %mask
+     {in_bounds = [true, false, true]} : vector<1x32x16xf16>, tensor<1x?x16xf16>
+  %r = vector.transfer_read %w[%c0, %c0, %c0], %pad
+     {in_bounds = [true, true, true]} : tensor<1x?x16xf16>, vector<1x32x16xf16>
+  return %r : vector<1x32x16xf16>
+}
+
+// -----
+
+// Negative test: transfer_read is inside a vector.mask — the pattern must not
+// fold because the external mask is not visible through getMask().
+// CHECK-LABEL: func @negative_fold_transfer_raw_vector_mask
+// CHECK:         vector.transfer_write
+// CHECK:         vector.mask
+// CHECK:         vec...
[truncated]

``````````
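The decisions made in `FoldTransferReadAfterTransferWrite::matchAndRewrite` can be summarized with the following sketch. It is a string-level model, not MLIR code: `ValueId`, `IndexOperand`, `indicesCompatible`, and `describeFold` are hypothetical names. An SSA value is modelled as an integer id, mirroring the fact that the pattern compares `mlir::Value`s (masks and indices) by identity rather than by contents.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of the pattern's decision logic. An SSA value is an
// integer id; as with mlir::Value, two operands compare equal only when they
// are the same SSA value, not when they merely hold equal contents.
using ValueId = int;

struct IndexOperand {
  ValueId id;
  std::optional<long> constant; // engaged if the index is a known constant
};

static bool isConstantZero(const IndexOperand &idx) {
  return idx.constant.has_value() && *idx.constant == 0;
}

// Index precondition: every index on both ops is a constant zero, or both
// ops use exactly the same index values.
bool indicesCompatible(const std::vector<IndexOperand> &readIdx,
                       const std::vector<IndexOperand> &writeIdx) {
  bool allZero = true;
  for (const IndexOperand &idx : readIdx)
    allZero = allZero && isConstantZero(idx);
  for (const IndexOperand &idx : writeIdx)
    allZero = allZero && isConstantZero(idx);
  if (allZero)
    return true;
  if (readIdx.size() != writeIdx.size())
    return false;
  for (std::size_t k = 0; k < readIdx.size(); ++k)
    if (readIdx[k].id != writeIdx[k].id)
      return false;
  return true;
}

// Returns a textual form of the replacement IR. The inner select (and the
// extra read of the original tensor) is only built when wMask is present and
// is not the same SSA value as rMask; the outer select only when the read is
// masked.
std::string describeFold(std::optional<ValueId> wMask,
                         std::optional<ValueId> rMask) {
  std::string inner = "val";
  const bool needsOriginal = wMask.has_value() && wMask != rMask;
  if (needsOriginal)
    inner = "select(wMask, val, read(original))";
  if (!rMask)
    return inner;
  return "select(rMask, " + inner + ", broadcast(pad))";
}
```

Because the mask check is on SSA identity, two distinct mask values (for example, two separate function arguments, as in `@fold_transfer_raw_oob_write_different_masks`) take the nested-select path even if their contents happen to be equal at runtime. Likewise, two distinct constant-zero index values pass the index check.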

</details>


https://github.com/llvm/llvm-project/pull/196608

