[Mlir-commits] [llvm] [mlir] [MLIR][AMDGPU] Adding Vector to AMDGPU conversion lowering (PR #131803)

Zhuoran Yin llvmlistbot at llvm.org
Thu Mar 20 14:25:52 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/131803

>From 7f6d6efa5a9b3fab973bfc915037577425949c68 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 18 Mar 2025 13:02:00 +0000
Subject: [PATCH 1/5] Adding Vector to AMDGPU conversion lowering

---
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  10 ++
 .../VectorToAMDGPU/VectorToAMDGPU.h           |  24 +++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../Conversion/VectorToAMDGPU/CMakeLists.txt  |  18 +++
 .../VectorToAMDGPU/VectorToAMDGPU.cpp         | 147 ++++++++++++++++++
 .../vector-transfer-read-to-vector-load.mlir  |  68 ++++++++
 7 files changed, 269 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
 create mode 100644 mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
 create mode 100644 mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ccd862f67c068..ed5e8de8787f7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -73,6 +73,7 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..1845d0235183e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1333,6 +1333,16 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
   let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
+  let summary = "Lower the operations from the vector dialect into the AMDGPU "
+                "dialect";
+  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+} 
+
 //===----------------------------------------------------------------------===//
 // ArmSMEToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
new file mode 100644
index 0000000000000..be96061a23b08
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
@@ -0,0 +1,24 @@
+//===- VectorToAMDGPU.h - Vector to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+#define MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateVectorToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b6c21440c571c..1e4cbd2be4c96 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -66,6 +66,7 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMDGPU)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..2ad46c26d0a57
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRVectorToAMDGPU
+  VectorToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMDGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
new file mode 100644
index 0000000000000..248b84a7fdc98
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -0,0 +1,147 @@
+//===- VectorToAMDGPU.cpp - Vector to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// This pattern supports lowering of:
+/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
+/// `vector.broadcast` if all of the following hold:
+/// - The transfer op is masked.
+/// - The memref is in buffer address space.
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
+/// pass.
+static LogicalResult
+transferPreconditions(PatternRewriter &rewriter,
+                      VectorTransferOpInterface xferOp,
+                      SmallVector<unsigned> &broadcastedDims,
+                      VectorType &unbroadcastedVectorType) {
+  if (!xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
+
+  // Permutations are handled by VectorToSCF or
+  // populateVectorTransferPermutationMapLoweringPatterns.
+  // We let the 0-d corner case pass-through as it is supported.
+  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
+          &broadcastedDims))
+    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
+
+  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!addrSpace ||
+      llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+          amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
+
+  // Non-unit strides are handled by VectorToSCF.
+  if (!memRefType.isLastDimUnitStride())
+    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+
+  // If there is broadcasting involved then we first load the unbroadcasted
+  // vector, and then broadcast it with `vector.broadcast`.
+  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
+  for (unsigned i : broadcastedDims)
+    unbroadcastedVectorShape[i] = 1;
+  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
+      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+
+  // `vector.load` supports vector types as memref's elements only when the
+  // resulting vector type is the same as the element type.
+  auto memrefElTy = memRefType.getElementType();
+  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
+    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
+
+  // Otherwise, element types of the memref and the vector must match.
+  if (!isa<VectorType>(memrefElTy) &&
+      memrefElTy != xferOp.getVectorType().getElementType())
+    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
+
+  // Out-of-bounds dims are handled by MaterializeTransferMask.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
+
+  if (xferOp.getVectorType().getRank() != 1)
+    // vector.maskedload operates on 1-D vectors.
+    return rewriter.notifyMatchFailure(
+        xferOp, "vector type is not rank 1, can't create masked load, needs "
+                "VectorToSCF");
+
+  return success();
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    SmallVector<unsigned> broadcastedDims;
+    VectorType unbroadcastedVectorType;
+    if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+                                     unbroadcastedVectorType))) {
+      return failure();
+    }
+
+    Value fill = rewriter.create<vector::SplatOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+    Value load = rewriter.create<vector::LoadOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
+        readOp.getIndices());
+    Value res = rewriter.create<arith::SelectOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
+                                                 readOp.getVectorType(), res);
+    }
+
+    rewriter.replaceOp(readOp, res);
+
+    return success();
+  }
+};
+
+void mlir::populateVectorToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering>(patterns.getContext());
+}
+
+struct ConvertVectorToAMDGPUPass
+    : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToAMDGPUConversionPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
new file mode 100644
index 0000000000000..30d9814cc0621
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_regular(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+#broadcast_1d = affine_map<(d0, d1) -> (0)>
+func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true], permutation_map = #broadcast_1d}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
+// CHECK: return %[[BROADCAST]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true]}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<1xf32>

>From 4610b01665b363c00c14d8a3674e201dd214c435 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 18 Mar 2025 21:00:53 +0000
Subject: [PATCH 2/5] Addressing review feedbacks

---
 .../VectorToAMDGPU/VectorToAMDGPU.cpp         | 48 +++++++++----------
 .../vector-transfer-read-to-vector-load.mlir  | 17 ++++---
 2 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
index 248b84a7fdc98..569e8f8bb4e3a 100644
--- a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -15,7 +15,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
@@ -36,17 +36,16 @@ using namespace mlir;
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
 /// Note: those conditions mostly come from TransferReadToVectorLoadLowering
 /// pass.
-static LogicalResult
-transferPreconditions(PatternRewriter &rewriter,
-                      VectorTransferOpInterface xferOp,
-                      SmallVector<unsigned> &broadcastedDims,
-                      VectorType &unbroadcastedVectorType) {
+static LogicalResult transferPreconditions(
+    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
+    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
   if (!xferOp.getMask())
     return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
 
   // Permutations are handled by VectorToSCF or
   // populateVectorTransferPermutationMapLoweringPatterns.
   // We let the 0-d corner case pass-through as it is supported.
+  SmallVector<unsigned> broadcastedDims;
   if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
           &broadcastedDims))
     return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
@@ -56,9 +55,8 @@ transferPreconditions(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(xferOp, "not a memref source");
 
   Attribute addrSpace = memRefType.getMemorySpace();
-  if (!addrSpace ||
-      llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
-          amdgpu::AddressSpace::FatRawBuffer)
+  if (!addrSpace || dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+                        amdgpu::AddressSpace::FatRawBuffer)
     return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
 
   // Non-unit strides are handled by VectorToSCF.
@@ -73,6 +71,7 @@ transferPreconditions(PatternRewriter &rewriter,
     unbroadcastedVectorShape[i] = 1;
   unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
       unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+  requiresBroadcasting = !broadcastedDims.empty();
 
   // `vector.load` supports vector types as memref's elements only when the
   // resulting vector type is the same as the element type.
@@ -98,31 +97,31 @@ transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
 
-    SmallVector<unsigned> broadcastedDims;
+    bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
-    if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                      unbroadcastedVectorType))) {
       return failure();
     }
 
-    Value fill = rewriter.create<vector::SplatOp>(
-        readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+    Location loc = readOp.getLoc();
+    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getPadding());
     Value load = rewriter.create<vector::LoadOp>(
-        readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
-        readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(
-        readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                 readOp.getMask(), load, fill);
 
     // Insert a broadcasting op if required.
-    if (!broadcastedDims.empty()) {
-      res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
-                                                 readOp.getVectorType(), res);
+    if (requiresBroadcasting) {
+      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                 res);
     }
 
     rewriter.replaceOp(readOp, res);
@@ -136,12 +135,11 @@ void mlir::populateVectorToAMDGPUConversionPatterns(
   patterns.add<TransferReadLowering>(patterns.getContext());
 }
 
-struct ConvertVectorToAMDGPUPass
+struct ConvertVectorToAMDGPUPass final
     : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateVectorToAMDGPUConversionPatterns(patterns);
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
-      return signalPassFailure();
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
   }
 };
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
index 30d9814cc0621..d0a79045c86da 100644
--- a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
+++ b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-amdgpu --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
@@ -9,9 +9,10 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: return %[[SELECT]] : vector<4xf32>
 
 // -----
@@ -43,9 +44,10 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
       : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
 // CHECK: return %[[BROADCAST]] : vector<4xf32>
 
@@ -62,7 +64,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
       : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
   return %res : vector<1xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: return %[[SELECT]] : vector<1xf32>

>From f7ca23b6d3dd559bac386dcb65cde1dfe7ee1d71 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Wed, 19 Mar 2025 13:29:31 +0000
Subject: [PATCH 3/5] Addressing review feedbacks

---
 mlir/include/mlir/Conversion/Passes.td                | 2 +-
 mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp | 6 +++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1845d0235183e..d8f00c7109a3d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1340,7 +1340,7 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
 def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
   let summary = "Lower the operations from the vector dialect into the AMDGPU "
                 "dialect";
-  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+  let dependentDialects = ["vector::VectorDialect"];
 } 
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
index 569e8f8bb4e3a..b923546d7aaa5 100644
--- a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -97,6 +97,8 @@ static LogicalResult transferPreconditions(
   return success();
 }
 
+namespace {
+
 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -130,13 +132,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+} // namespace
+
 void mlir::populateVectorToAMDGPUConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering>(patterns.getContext());
 }
 
 struct ConvertVectorToAMDGPUPass final
-    : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+    : impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateVectorToAMDGPUConversionPatterns(patterns);

>From 2c292f8958649170c00709211f017800453318f2 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Wed, 19 Mar 2025 19:30:07 +0000
Subject: [PATCH 4/5] Adding bazel target for VectorToAMDGPU

---
 .../llvm-project-overlay/mlir/BUILD.bazel     | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 95e2788906525..978779ea2bfb8 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4466,6 +4466,7 @@ cc_library(
         ":TosaToTensor",
         ":UBToLLVM",
         ":UBToSPIRV",
+        ":VectorToAMDGPU",
         ":VectorToArmSME",
         ":VectorToGPU",
         ":VectorToLLVM",
@@ -12189,6 +12190,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "VectorToAMDGPU",
+    srcs = glob([
+        "lib/Conversion/VectorToAMDGPU/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/VectorToAMDGPU/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":AMDGPUDialect",
+        ":VectorDialect",
+        ":ConversionPassIncGen",
+        ":IR",
+        ":MemRefDialect",
+        ":Pass",
+        ":TransformUtils",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "VectorToArmSME",
     srcs = glob([

>From 8eae77374661ea4a30287f8d82de7595568fdc3f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 20 Mar 2025 21:30:10 +0000
Subject: [PATCH 5/5] Amend from conversion to dialect rewrite pattern

---
 mlir/include/mlir/Conversion/Passes.h         |  1 -
 mlir/include/mlir/Conversion/Passes.td        | 10 -------
 .../VectorToAMDGPU/VectorToAMDGPU.h           | 24 ---------------
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h   |  4 +++
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  | 14 +++++++++
 mlir/lib/Conversion/CMakeLists.txt            |  1 -
 .../Conversion/VectorToAMDGPU/CMakeLists.txt  | 18 ------------
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |  1 +
 .../AMDGPU/Transforms/TransferReadToLoad.cpp} | 29 +++++++++++--------
 .../AMDGPU/transfer-read-to-load.mlir}        | 17 ++++++++++-
 .../llvm-project-overlay/mlir/BUILD.bazel     | 22 --------------
 11 files changed, 52 insertions(+), 89 deletions(-)
 delete mode 100644 mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
 delete mode 100644 mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
 rename mlir/lib/{Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp => Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp} (87%)
 rename mlir/test/{Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir => Dialect/AMDGPU/transfer-read-to-load.mlir} (78%)

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ed5e8de8787f7..ccd862f67c068 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -73,7 +73,6 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
-#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d8f00c7109a3d..bbba495e613b2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1333,16 +1333,6 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
   let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
-//===----------------------------------------------------------------------===//
-// VectorToAMDGPU
-//===----------------------------------------------------------------------===//
-
-def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
-  let summary = "Lower the operations from the vector dialect into the AMDGPU "
-                "dialect";
-  let dependentDialects = ["vector::VectorDialect"];
-} 
-
 //===----------------------------------------------------------------------===//
 // ArmSMEToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
deleted file mode 100644
index be96061a23b08..0000000000000
--- a/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
+++ /dev/null
@@ -1,24 +0,0 @@
-//===- VectorToAMDGPU.h - Vector to AMDGPU dialect conversion ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
-#define MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-class RewritePatternSet;
-class Pass;
-
-#define GEN_PASS_DECL_CONVERTVECTORTOAMDGPUPASS
-#include "mlir/Conversion/Passes.h.inc"
-
-void populateVectorToAMDGPUConversionPatterns(RewritePatternSet &patterns);
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index c3ae7930e8ec8..94dd9e3a29331 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -30,6 +31,9 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
                                           Chipset chipset);
 
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
+
+void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 6d0bcd6e1066e..cfb7a2656144e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,4 +51,18 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
+def AmdgpuTransferReadToLoadPass : Pass<"convert-transfer-read-to-load"> {
+  let summary = "Lower the operations from the vector transfer_read to vector load";
+  let description = [{
+    This pass creates a transfer read op lowering. A vector transfer read op
+    will be lowered to a combination of vector.load, arith.select and
+    vector.broadcast.
+
+    This pattern will make it possible for masked transfer_read to be lowered
+    towards buffer load with bounds check, allowing a more optimized global
+    load accessing pattern compared with existing implementation of
+    llvm.intr.masked.load on vectors.
+  }];
+  let dependentDialects = [];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 1e4cbd2be4c96..b6c21440c571c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -66,7 +66,6 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
-add_subdirectory(VectorToAMDGPU)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
deleted file mode 100644
index 2ad46c26d0a57..0000000000000
--- a/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_mlir_conversion_library(MLIRVectorToAMDGPU
-  VectorToAMDGPU.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMDGPU
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRAMDGPUDialect
-  MLIRVectorDialect
-  MLIRPass
-  MLIRTransforms
-  )
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 3d4567bff1e32..bc5b6e9186449 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
+  TransferReadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
similarity index 87%
rename from mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
rename to mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index b923546d7aaa5..3c1a2eb962037 100644
--- a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -1,4 +1,4 @@
-//===- VectorToAMDGPU.cpp - Vector to AMDGPU dialect conversion ---------===//
+//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -17,12 +17,13 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
 
 using namespace mlir;
+using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of:
 /// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
@@ -55,8 +56,11 @@ static LogicalResult transferPreconditions(
     return rewriter.notifyMatchFailure(xferOp, "not a memref source");
 
   Attribute addrSpace = memRefType.getMemorySpace();
-  if (!addrSpace || dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
-                        amdgpu::AddressSpace::FatRawBuffer)
+  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
+    return rewriter.notifyMatchFailure(xferOp, "no address space");
+
+  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+      amdgpu::AddressSpace::FatRawBuffer)
     return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
 
   // Non-unit strides are handled by VectorToSCF.
@@ -134,16 +138,17 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
 } // namespace
 
-void mlir::populateVectorToAMDGPUConversionPatterns(
+void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering>(patterns.getContext());
 }
 
-struct ConvertVectorToAMDGPUPass final
-    : impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+struct AmdgpuTransferReadToLoadPass final
+    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
+          AmdgpuTransferReadToLoadPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorToAMDGPUConversionPatterns(patterns);
+    populateAmdgpuTransferReadToLoadPatterns(patterns);
     walkAndApplyPatterns(getOperation(), std::move(patterns));
   }
 };
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
similarity index 78%
rename from mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
rename to mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index d0a79045c86da..c39ce245deacd 100644
--- a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-amdgpu --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-transfer-read-to-load --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
@@ -32,6 +32,21 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
 
 // -----
 
+// CHECK-LABEL: func @transfer_to_maskedload_addrspace(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
 // CHECK-LABEL: func @transfer_broadcasting(
 // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 978779ea2bfb8..95e2788906525 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4466,7 +4466,6 @@ cc_library(
         ":TosaToTensor",
         ":UBToLLVM",
         ":UBToSPIRV",
-        ":VectorToAMDGPU",
         ":VectorToArmSME",
         ":VectorToGPU",
         ":VectorToLLVM",
@@ -12190,27 +12189,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "VectorToAMDGPU",
-    srcs = glob([
-        "lib/Conversion/VectorToAMDGPU/*.cpp",
-    ]),
-    hdrs = glob([
-        "include/mlir/Conversion/VectorToAMDGPU/*.h",
-    ]),
-    includes = ["include"],
-    deps = [
-        ":AMDGPUDialect",
-        ":VectorDialect",
-        ":ConversionPassIncGen",
-        ":IR",
-        ":MemRefDialect",
-        ":Pass",
-        ":TransformUtils",
-        "//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "VectorToArmSME",
     srcs = glob([



More information about the Mlir-commits mailing list