[Mlir-commits] [mlir] [MLIR][AMDGPU] Redirect transfer read to masked load lowering (PR #146705)

Zhuoran Yin llvmlistbot at llvm.org
Wed Jul 2 07:33:04 PDT 2025


https://github.com/jerryyin created https://github.com/llvm/llvm-project/pull/146705

This PR reworks https://github.com/llvm/llvm-project/pull/131803. Instead of applying the optimization to the `vector.transfer_read` op, which is too high level, it redirects the pre-existing pattern onto the `vector.maskedload` op. This allows a simpler lowering pattern, and it also allows the pass to be used in a target-dependent pipeline.

>From c86b23a547512ee27c540bcb711823e719122d6c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 1 Jul 2025 15:32:32 +0000
Subject: [PATCH] Redirect transfer read to masked load lowering

Signed-off-by: jerryyin <zhuoryin at amd.com>
---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h   |   6 +-
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  |   4 +-
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |   2 +-
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 167 ++++++++++++
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 239 ------------------
 ...d-to-load.mlir => maskedload-to-load.mlir} |  78 ++----
 6 files changed, 199 insertions(+), 297 deletions(-)
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
 delete mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
 rename mlir/test/Dialect/AMDGPU/{transfer-read-to-load.mlir => maskedload-to-load.mlir} (56%)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index a52ee2ee89caf..cc2f543e79f69 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -23,7 +23,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
-#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);
 
-void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
+void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 0e858108acf35..8d0e6829ab0cc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
-def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
-  let summary = "Lower the operations from the vector transfer_read to vector load";
+def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
+  let summary = "Lower the operations from the vector maskedload to vector load";
   let description = [{
     This pass creates a transfer read op lowering optimization. The lowering
     will produce a conditional check at runtime. If within bounds, a vector
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 8709a27e0168e..17bbe54ea6c0c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
-  TransferReadToLoad.cpp
+  MaskedloadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
new file mode 100644
index 0000000000000..9a368f372c296
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -0,0 +1,167 @@
+//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// Succeeds iff the base of `maskedOp` is a memref in the AMDGPU
+/// `fat_raw_buffer` address space; otherwise records why the match failed.
+static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
+                                           vector::MaskedLoadOp maskedOp) {
+  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+
+  auto addrSpace = dyn_cast_if_present<amdgpu::AddressSpaceAttr>(
+      memRefType.getMemorySpace());
+  if (!addrSpace)
+    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+
+  if (addrSpace.getValue() != amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+
+  return success();
+}
+
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::MaskedLoadOp maskedOp) {
+  // Plain load, then lanes whose mask bit is clear take the pass-through.
+  VectorType resultType = maskedOp.getVectorType();
+  Value unmaskedLoad = builder.create<vector::LoadOp>(
+      loc, resultType, maskedOp.getBase(), maskedOp.getIndices());
+  return builder.create<arith::SelectOp>(loc, resultType, maskedOp.getMask(),
+                                         unmaskedLoad, maskedOp.getPassThru());
+}
+
+static constexpr llvm::StringLiteral kMaskedloadNeedsMask =
+    "amdgpu.buffer_maskedload_needs_mask"; // Tags slow-path clones.
+
+namespace {
+/// Lowers a fat_raw_buffer `vector.maskedload` to a runtime-guarded load.
+struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
+                                PatternRewriter &rewriter) const override {
+    // Slow-path clones carry this marker; leave them as real maskedloads.
+    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
+      return failure();
+
+    if (failed(baseInBufferAddrSpace(rewriter, maskedOp)))
+      return failure();
+
+    Location loc = maskedOp.getLoc();
+    Value src = maskedOp.getBase();
+
+    VectorType vectorType = maskedOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = maskedOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // delta = linearizedSize - linearIndex (elements left until buffer end)
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value totalSize = getValueOrCreateConstantIndexOp(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize (unsigned compare)
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (delta % elements_per_word != 0)
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // Fall back to the default maskedload lowering only if the load is both
+    // out-of-bounds and not word aligned. The fallback ensures correct results
+    // at the end of the buffer, where a buffer load that crosses the boundary
+    // returns zeros for the whole partially out-of-bounds word.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location nestedLoc) {
+      Operation *read = builder.clone(*maskedOp.getOperation());
+      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(nestedLoc, readResult);
+    };
+
+    // Use the region builder handed in by scf.if, not the outer rewriter.
+    auto elseBuilder = [&](OpBuilder &builder, Location nestedLoc) {
+      Value res = createVectorLoadForMaskedLoad(builder, nestedLoc, maskedOp);
+      builder.create<scf::YieldOp>(nestedLoc, res);
+    };
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(maskedOp, ifOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+}
+
+/// Pass wrapper: greedily applies the maskedload-to-load patterns.
+struct AmdgpuMaskedloadToLoadPass final
+    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
+  void runOnOperation() override {
+    RewritePatternSet patternSet(&getContext());
+    populateAmdgpuMaskedloadToLoadPatterns(patternSet);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patternSet))))
+      signalPassFailure();
+  }
+};
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
deleted file mode 100644
index f5b12a9524cc9..0000000000000
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ /dev/null
@@ -1,239 +0,0 @@
-//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
-
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir::amdgpu {
-#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
-
-/// This pattern supports lowering of:
-/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
-/// `vector.broadcast` if all of the following hold:
-/// - The transfer op is masked.
-/// - The memref is in buffer address space.
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-///   result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
-/// pass.
-static LogicalResult transferPreconditions(
-    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
-    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
-  if (!xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
-
-  // Permutations are handled by VectorToSCF or
-  // populateVectorTransferPermutationMapLoweringPatterns.
-  // We let the 0-d corner case pass-through as it is supported.
-  SmallVector<unsigned> broadcastedDims;
-  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
-          &broadcastedDims))
-    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
-
-  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
-  if (!memRefType)
-    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
-
-  Attribute addrSpace = memRefType.getMemorySpace();
-  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(xferOp, "no address space");
-
-  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
-      amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
-
-  // Non-unit strides are handled by VectorToSCF.
-  if (!memRefType.isLastDimUnitStride())
-    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
-
-  if (memRefType.getElementTypeBitWidth() < 8)
-    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
-
-  // If there is broadcasting involved then we first load the unbroadcasted
-  // vector, and then broadcast it with `vector.broadcast`.
-  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
-  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
-  for (unsigned i : broadcastedDims)
-    unbroadcastedVectorShape[i] = 1;
-  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
-      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
-  requiresBroadcasting = !broadcastedDims.empty();
-
-  // `vector.load` supports vector types as memref's elements only when the
-  // resulting vector type is the same as the element type.
-  auto memrefElTy = memRefType.getElementType();
-  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
-    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
-
-  // Otherwise, element types of the memref and the vector must match.
-  if (!isa<VectorType>(memrefElTy) &&
-      memrefElTy != xferOp.getVectorType().getElementType())
-    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
-
-  // Out-of-bounds dims are handled by MaterializeTransferMask.
-  if (xferOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
-
-  if (xferOp.getVectorType().getRank() != 1)
-    // vector.maskedload operates on 1-D vectors.
-    return rewriter.notifyMatchFailure(
-        xferOp, "vector type is not rank 1, can't create masked load, needs "
-                "VectorToSCF");
-
-  return success();
-}
-
-static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::TransferReadOp readOp,
-                                           bool requiresBroadcasting,
-                                           VectorType unbroadcastedVectorType) {
-  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                               readOp.getPadding());
-  Value load = builder.create<vector::LoadOp>(
-      loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                              readOp.getMask(), load, fill);
-  // Insert a broadcasting op if required.
-  if (requiresBroadcasting) {
-    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
-  }
-  return res;
-}
-
-static constexpr char kTransferReadNeedsMask[] =
-    "amdgpu.buffer_transfer_read_needs_mask";
-
-namespace {
-
-struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr(kTransferReadNeedsMask))
-      return failure();
-
-    bool requiresBroadcasting = false;
-    VectorType unbroadcastedVectorType;
-    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
-                                     unbroadcastedVectorType))) {
-      return failure();
-    }
-
-    Location loc = readOp.getLoc();
-    Value src = readOp.getBase();
-
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
-    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
-    SmallVector<OpFoldResult> indices = readOp.getIndices();
-
-    auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
-    SmallVector<OpFoldResult> strides =
-        stridedMetadata.getConstifiedMixedStrides();
-    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
-    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
-    memref::LinearizedMemRefInfo linearizedInfo;
-    OpFoldResult linearizedIndices;
-    std::tie(linearizedInfo, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
-                                                 elementBitWidth, offset, sizes,
-                                                 strides, indices);
-
-    // delta = bufferSize - linearizedOffset
-    Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
-    Value totalSize = getValueOrCreateConstantIndexOp(
-        rewriter, loc, linearizedInfo.linearizedSize);
-    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
-
-    // 1) check if delta < vectorSize
-    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
-
-    // 2) check if (detla % elements_per_word != 0)
-    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::divideCeil(32, elementBitWidth));
-    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne,
-        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
-        rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned. The fallback ensures correct results
-    // when loading at the boundary of the buffer since buffer load returns
-    // inconsistent zeros for the whole word when boundary is crossed.
-    Value ifCondition =
-        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
-
-    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
-      Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
-      Value readResult = read->getResult(0);
-      builder.create<scf::YieldOp>(loc, readResult);
-    };
-
-    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(
-          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
-      rewriter.create<scf::YieldOp>(loc, res);
-    };
-
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
-
-    rewriter.replaceOp(readOp, ifOp);
-
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<TransferReadLowering>(patterns.getContext(), benefit);
-}
-
-struct AmdgpuTransferReadToLoadPass final
-    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
-          AmdgpuTransferReadToLoadPass> {
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateAmdgpuTransferReadToLoadPatterns(patterns);
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
similarity index 56%
rename from mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
rename to mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index 20999af10553e..febe46bf7a759 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -1,17 +1,17 @@
-// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --amdgpu-maskedload-to-load --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK: %[[IF:.*]] = scf.if
-// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 
 // CHECK: } else {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
@@ -25,10 +25,10 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
 // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16(
 // CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
 // CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
-// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>)
-func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> {
-  %cf0 = arith.constant 0.0 : f16
-  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
+// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>
+// CHECK-SAME: %[[ARG4:.+]]: vector<4xf16>
+func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>, %passthru : vector<4xf16>) -> vector<4xf16> {
+  %res = vector.maskedload %mem[%idx0, %idx1], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
   return %res : vector<4xf16>
 }
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0
@@ -45,7 +45,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 
 // CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
 // CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
-// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: } else {
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: return %[[IF]] : vector<4xf16>
@@ -58,13 +58,11 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
 // CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> {
-  %cf0 = arith.constant 0 : i8
-  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+// CHECK-SAME: %[[ARG4:.*]]: vector<4xi8>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>, %passthru : vector<4xi8>) -> vector<4xi8> {
+  %res = vector.maskedload %mem[%idx0, %idx1], %mask, %passthru : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xi8> into vector<4xi8>
   return %res : vector<4xi8>
 }
-
-// CHECK:     %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
 // CHECK:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK:     %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
@@ -79,13 +77,12 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32>
+func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
-// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[ARG3]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -94,49 +91,26 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32>
+func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
-// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[ARG3]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
 
-// CHECK-LABEL: func @transfer_broadcasting(
-// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
-// CHECK-SAME: %[[ARG1:.*]]: index
-// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
-#broadcast_1d = affine_map<(d0, d1) -> (0)>
-func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
-    {in_bounds = [true], permutation_map = #broadcast_1d}
-      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
-  return %res : vector<4xf32>
-}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK: %[[IF:.*]] = scf.if
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
-// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-
-// -----
-
 // CHECK-LABEL: func @transfer_scalar(
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
-func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
-    {in_bounds = [true]}
-      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+// CHECK-SAME: %[[ARG3:.*]]: vector<1xf32>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>, %passthru : vector<1xf32>) -> vector<1xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xi1>, vector<1xf32> into vector<1xf32>
   return %res : vector<1xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]



More information about the Mlir-commits mailing list