[Mlir-commits] [mlir] [mlir][vector] Add FoldTransferReadAfterTransferWrite. (PR #196608)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri May 8 11:47:21 PDT 2026
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/196608
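Adds a canonicalization pattern that folds a vector.transfer_read of the result of a vector.transfer_write on the same tensor region (read-after-write, RAW) into the written value, combined with selects on the masks and padding where needed.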
>From 31fc4bc2138125fdad0fa33ba6f81fd35cdf5de4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 8 May 2026 14:44:18 -0400
Subject: [PATCH] [mlir][vector] Add FoldTransferReadAfterTransferWrite.
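Adds a canonicalization pattern that folds a vector.transfer_read of the result of a vector.transfer_write on the same tensor region (read-after-write, RAW) into the written value, combined with selects on the masks and padding where needed.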
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 86 ++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 157 ++++++++++++++++++---
2 files changed, 224 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 51be1e4431e70..6dd3cb950d551 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBMatchers.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -5625,11 +5626,94 @@ struct TransferReadAfterWriteToBroadcast
return success();
}
};
+
+/// Folds a transfer_read that reads from the result of a transfer_write on
+/// the same region (Read-After-Write) into arithmetic on the written value,
+/// the original tensor, the masks, and the read's padding.
+///
+/// The general semantics are:
+///
+///   written_tensor[i] = wMask[i] ? valToStore[i] : original[i]
+///   result[i]         = rMask[i] ? written_tensor[i] : rPad
+///
+/// Which gives:
+///   result = select(rMask, select(wMask, valToStore, original),
+///                   broadcast(rPad))
+///
+/// Special cases avoid emitting unnecessary IR:
+///   - No wMask (unmasked write): wMask is implicitly all-true, inner select
+///     collapses to valToStore.
+///   - No rMask (unmasked read): rMask is implicitly all-true, outer select
+///     collapses away.
+///   - wMask == rMask: the original tensor is never needed (anywhere rMask is
+///     true, wMask is also true), so the inner select collapses to valToStore.
+///
+/// The pattern only applies when all of the following hold:
+///   - both ops have pure tensor semantics;
+///   - the stored vector type equals the read vector type;
+///   - the indices are the same SSA values, or are all zero on both ops;
+///   - both permutation maps are minor identities;
+///   - the read has no potentially out-of-bounds dimension. The formulas above
+///     do not model out-of-bounds lanes, which the write skips and the read
+///     fills with padding, so forwarding valToStore there would be wrong.
+///
+/// After bufferization, this generally removes the need for materializing the
+/// write to memory.
+struct FoldTransferReadAfterTransferWrite
+    : public OpRewritePattern<TransferReadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    if (!readOp.hasPureTensorSemantics())
+      return failure();
+
+    auto writeOp =
+        dyn_cast_if_present<TransferWriteOp>(readOp.getBase().getDefiningOp());
+    if (!writeOp || !writeOp.hasPureTensorSemantics())
+      return failure();
+
+    Value valToStore = writeOp.getValueToStore();
+    if (valToStore.getType() != readOp.getType())
+      return failure();
+
+    // Both ops must address the same region: either the index operands are
+    // the very same SSA values, or every index on both ops is zero (possibly
+    // through distinct constants).
+    auto allZero = [](ValueRange indices) {
+      return llvm::all_of(indices, [](Value v) { return isZeroInteger(v); });
+    };
+    if (readOp.getIndices() != writeOp.getIndices() &&
+        !(allZero(readOp.getIndices()) && allZero(writeOp.getIndices())))
+      return failure();
+
+    if (!readOp.getPermutationMap().isMinorIdentity() ||
+        !writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    // An out-of-bounds lane is not written by the transfer_write and reads
+    // back as padding, so the written value must not be forwarded for it.
+    // Requiring a fully in-bounds read keeps every forwarded lane inside the
+    // tensor (the write targets the same region with the same vector type).
+    if (readOp.hasOutOfBoundsDim())
+      return failure();
+
+    TypedValue<VectorType> wMask = writeOp.getMask();
+    TypedValue<VectorType> rMask = readOp.getMask();
+
+    // Build the inner value: select(wMask, valToStore, original).
+    // When wMask is absent (unmasked write) or wMask == rMask (original is
+    // never accessed), this simplifies to just valToStore.
+    Value inner = valToStore;
+    if (wMask && wMask != rMask) {
+      // Lanes the write left untouched come from the tensor the write updated,
+      // so re-read that tensor with the read's indices, map, padding and
+      // in_bounds, but without a mask.
+      Value originalRead = TransferReadOp::create(
+          rewriter, readOp.getLoc(), readOp.getType(), writeOp.getBase(),
+          readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
+          /*mask=*/Value(), readOp.getInBoundsAttr());
+      inner = arith::SelectOp::create(rewriter, readOp.getLoc(), wMask,
+                                      valToStore, originalRead);
+    }
+
+    // Unmasked read: every lane observes the inner value directly.
+    if (!rMask) {
+      rewriter.replaceOp(readOp, inner);
+      return success();
+    }
+
+    // Masked read: lanes disabled by rMask yield the (broadcast) padding.
+    Value rPad = readOp.getPadding();
+    Value padVal = BroadcastOp::create(rewriter, rPad.getLoc(),
+                                       valToStore.getType(), rPad);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(readOp, rMask, inner, padVal);
+    return success();
+  }
+};
} // namespace
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<TransferReadAfterWriteToBroadcast>(context);
+ // Also register the general read-after-write fold defined above.
+ results.add<TransferReadAfterWriteToBroadcast,
+ FoldTransferReadAfterTransferWrite>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6aa92ab79a0dd..daae532f9fd50 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1991,24 +1991,6 @@ func.func @negative_store_to_load_tensor_memref(
// -----
-// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
-// CHECK-NOT: vector.broadcast
-// CHECK-NOT: vector.transpose
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
- %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
- %c0 = arith.constant 0 : index
- %cf0 = arith.constant 0.0 : f32
- %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
- vector<4x2xf32>, tensor<?x?xf32>
- %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
- tensor<?x?xf32>, vector<4x2xf32>
- return %0 : vector<4x2xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
@@ -2106,6 +2088,145 @@ func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
// -----
+// Write and read are masked with the same SSA mask, start at the same index
+// and are in bounds. Wherever the read mask is set the write mask is set too,
+// so the original tensor %t is never observed and the inner select collapses.
+// Expected result: select(mask, val, broadcast(pad)); the CHECK lines expect
+// both `val` and the broadcast pad to show up as dense constants.
+// CHECK-LABEL: func @fold_transfer_raw_both_masked
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[CST_1]], %[[CST_0]]
+// CHECK: return %[[SEL]]
+func.func @fold_transfer_raw_both_masked(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
+// Masked write, unmasked read: the read observes the stored value where the
+// write mask is set and the original tensor %t elsewhere, so it is replaced
+// with select(wMask, val, read(original)). The CHECK lines expect the new read
+// to keep the padding and in_bounds of the folded read and to carry no mask.
+// CHECK-LABEL: func @fold_transfer_raw_masked_write_unmasked_read
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[T]]{{.*}}, %[[CST]] {in_bounds = [true]}
+// CHECK-SAME: : tensor<128xf16>, vector<128xf16>
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[READ]]
+// CHECK: return %[[SEL]]
+func.func @fold_transfer_raw_masked_write_unmasked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
+// Both unmasked (same index, same vector type, in bounds): every lane that is
+// read was just written, so the read is directly replaced by the written
+// value. The CHECK-NOT lines additionally expect the transfer_write, which is
+// then unused, to be gone from the output.
+// CHECK-LABEL: func @fold_transfer_raw_both_unmasked
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: return %[[VAL]]
+func.func @fold_transfer_raw_both_unmasked(%t: tensor<128xf16>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
+// Unmasked write, masked read: lanes enabled by the read mask see the written
+// value and the remaining lanes see the padding, so the result is
+// select(rMask, val, broadcast(pad)). The CHECK lines expect the broadcast pad
+// to show up as a dense constant and both transfer ops to be gone.
+// CHECK-LABEL: func @fold_transfer_raw_unmasked_write_masked_read
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-DAG: %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VAL]], %[[PAD]]
+// CHECK: return %[[RES]]
+func.func @fold_transfer_raw_unmasked_write_masked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
+// Negative test: memref semantics — pattern must not fire. The fold only
+// handles tensors, where the read consumes the SSA result of the write; here
+// the write produces no result and both ops access the memref %m directly.
+// CHECK-LABEL: func @negative_fold_transfer_raw_memref
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_fold_transfer_raw_memref(%m: memref<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ vector.transfer_write %val, %m[%c0], %mask {in_bounds = [true]}
+ : vector<128xf16>, memref<128xf16>
+ %r = vector.transfer_read %m[%c0], %cst, %mask {in_bounds = [true]}
+ : memref<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
+// Negative test: type mismatch between written and read vectors
+// (vector<128xf16> is stored but vector<64xf16> is read back), so the stored
+// value cannot be forwarded as-is and both ops must remain.
+// CHECK-LABEL: func @negative_fold_transfer_raw_type_mismatch
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_fold_transfer_raw_type_mismatch(%t: tensor<128xf16>) -> vector<64xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<64xf16>
+ return %r : vector<64xf16>
+}
+
+// -----
+
+// Negative test: the write uses index %i and the read uses index %j. These are
+// distinct function arguments that are not known to be zero, so the two ops
+// cannot be shown to access the same region and both must remain.
+// CHECK-LABEL: func @negative_fold_transfer_raw_different_indices
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_fold_transfer_raw_different_indices(
+ %t: tensor<256xf16>, %i: index, %j: index) -> vector<128xf16> {
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%i] {in_bounds = [true]}
+ : vector<128xf16>, tensor<256xf16>
+ %r = vector.transfer_read %w[%j], %cst {in_bounds = [true]}
+ : tensor<256xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
More information about the Mlir-commits
mailing list