[Mlir-commits] [mlir] [MLIR][AMDGPU] Adding Vector to AMDGPU conversion lowering (PR #131803)
Zhuoran Yin
llvmlistbot at llvm.org
Wed Mar 19 06:51:18 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/131803
>From 7f6d6efa5a9b3fab973bfc915037577425949c68 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 18 Mar 2025 13:02:00 +0000
Subject: [PATCH 1/3] Adding Vector to AMDGPU conversion lowering
---
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 10 ++
.../VectorToAMDGPU/VectorToAMDGPU.h | 24 +++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/VectorToAMDGPU/CMakeLists.txt | 18 +++
.../VectorToAMDGPU/VectorToAMDGPU.cpp | 147 ++++++++++++++++++
.../vector-transfer-read-to-vector-load.mlir | 68 ++++++++
7 files changed, 269 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
create mode 100644 mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
create mode 100644 mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ccd862f67c068..ed5e8de8787f7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -73,6 +73,7 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..1845d0235183e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1333,6 +1333,16 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
}
+//===----------------------------------------------------------------------===//
+// VectorToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
+ let summary = "Lower the operations from the vector dialect into the AMDGPU "
+ "dialect";
+ let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArmSMEToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
new file mode 100644
index 0000000000000..be96061a23b08
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
@@ -0,0 +1,24 @@
+//===- VectorToAMDGPU.h - Vector to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+#define MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateVectorToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b6c21440c571c..1e4cbd2be4c96 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -66,6 +66,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMDGPU)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..2ad46c26d0a57
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRVectorToAMDGPU
+ VectorToAMDGPU.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMDGPU
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRAMDGPUDialect
+ MLIRVectorDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
new file mode 100644
index 0000000000000..248b84a7fdc98
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -0,0 +1,147 @@
+//===- VectorToAMDGPU.cpp - Vector to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// This pattern supports lowering of:
+/// `vector.transfer_read` to a combination of `vector.splat`, `vector.load`,
+/// `arith.select` and `vector.broadcast` when all of the following hold:
+/// - The transfer op has a mask; it selects between loaded and padding lanes.
+/// - The memref is in the `#amdgpu.address_space<fat_raw_buffer>` space.
+/// - The stride of the most minor memref dimension is 1.
+/// - All dims are in-bounds and the result vector is 1-D.
+/// - If the memref's element type is a vector type, it equals the
+/// unbroadcasted result vector type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
+/// pass.
+static LogicalResult
+transferPreconditions(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp,
+ SmallVector<unsigned> &broadcastedDims,
+ VectorType &unbroadcastedVectorType) {
+ if (!xferOp.getMask())
+ return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
+
+ // Permutations are handled by VectorToSCF or
+ // populateVectorTransferPermutationMapLoweringPatterns.
+ // We let the 0-d corner case pass-through as it is supported.
+ if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
+ &broadcastedDims))
+ return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
+
+ auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+ if (!memRefType)
+ return rewriter.notifyMatchFailure(xferOp, "not a memref source");
+
+ Attribute addrSpace = memRefType.getMemorySpace();
+ if (!addrSpace ||
+ llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+ amdgpu::AddressSpace::FatRawBuffer)
+ return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
+
+ // Non-unit strides are handled by VectorToSCF.
+ if (!memRefType.isLastDimUnitStride())
+ return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+
+ // If there is broadcasting involved then we first load the unbroadcasted
+ // vector, and then broadcast it with `vector.broadcast`.
+ ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+ SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
+ for (unsigned i : broadcastedDims)
+ unbroadcastedVectorShape[i] = 1;
+ unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
+ unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+
+ // `vector.load` supports vector types as memref's elements only when the
+ // resulting vector type is the same as the element type.
+ auto memrefElTy = memRefType.getElementType();
+ if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
+ return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
+
+ // Otherwise, element types of the memref and the vector must match.
+ if (!isa<VectorType>(memrefElTy) &&
+ memrefElTy != xferOp.getVectorType().getElementType())
+ return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
+
+ // Out-of-bounds dims are handled by MaterializeTransferMask.
+ if (xferOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
+
+ if (xferOp.getVectorType().getRank() != 1)
+ // Only 1-D vectors are handled here; other ranks are left to VectorToSCF.
+ return rewriter.notifyMatchFailure(
+ xferOp, "vector type is not rank 1, can't create masked load, needs "
+ "VectorToSCF");
+
+ return success();
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override {
+
+ SmallVector<unsigned> broadcastedDims;
+ VectorType unbroadcastedVectorType;
+ if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+ unbroadcastedVectorType))) {
+ return failure();
+ }
+
+ Value fill = rewriter.create<vector::SplatOp>(
+ readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+ Value load = rewriter.create<vector::LoadOp>(
+ readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
+ readOp.getIndices());
+ Value res = rewriter.create<arith::SelectOp>(
+ readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+
+ // Insert a broadcasting op if required.
+ if (!broadcastedDims.empty()) {
+ res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
+ readOp.getVectorType(), res);
+ }
+
+ rewriter.replaceOp(readOp, res);
+
+ return success();
+ }
+};
+
+void mlir::populateVectorToAMDGPUConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadLowering>(patterns.getContext());
+}
+
+struct ConvertVectorToAMDGPUPass
+ : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToAMDGPUConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
new file mode 100644
index 0000000000000..30d9814cc0621
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_regular(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+#broadcast_1d = affine_map<(d0, d1) -> (0)>
+func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+ {in_bounds = [true], permutation_map = #broadcast_1d}
+ : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
+// CHECK: return %[[BROADCAST]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+ {in_bounds = [true]}
+ : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+ return %res : vector<1xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<1xf32>
>From 4610b01665b363c00c14d8a3674e201dd214c435 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 18 Mar 2025 21:00:53 +0000
Subject: [PATCH 2/3] Address review feedback
---
.../VectorToAMDGPU/VectorToAMDGPU.cpp | 48 +++++++++----------
.../vector-transfer-read-to-vector-load.mlir | 17 ++++---
2 files changed, 33 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
index 248b84a7fdc98..569e8f8bb4e3a 100644
--- a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -15,7 +15,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
@@ -36,17 +36,16 @@ using namespace mlir;
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
/// pass.
-static LogicalResult
-transferPreconditions(PatternRewriter &rewriter,
- VectorTransferOpInterface xferOp,
- SmallVector<unsigned> &broadcastedDims,
- VectorType &unbroadcastedVectorType) {
+static LogicalResult transferPreconditions(
+ PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
+ bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
if (!xferOp.getMask())
return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
// We let the 0-d corner case pass-through as it is supported.
+ SmallVector<unsigned> broadcastedDims;
if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
&broadcastedDims))
return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
@@ -56,9 +55,8 @@ transferPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(xferOp, "not a memref source");
Attribute addrSpace = memRefType.getMemorySpace();
- if (!addrSpace ||
- llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
- amdgpu::AddressSpace::FatRawBuffer)
+ auto spaceAttr = dyn_cast_if_present<amdgpu::AddressSpaceAttr>(addrSpace);
+ if (!spaceAttr || spaceAttr.getValue() != amdgpu::AddressSpace::FatRawBuffer)
return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
// Non-unit strides are handled by VectorToSCF.
@@ -73,6 +71,7 @@ transferPreconditions(PatternRewriter &rewriter,
unbroadcastedVectorShape[i] = 1;
unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+ requiresBroadcasting = !broadcastedDims.empty();
// `vector.load` supports vector types as memref's elements only when the
// resulting vector type is the same as the element type.
@@ -98,31 +97,31 @@ transferPreconditions(PatternRewriter &rewriter,
return success();
}
-struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- SmallVector<unsigned> broadcastedDims;
+ bool requiresBroadcasting = false;
VectorType unbroadcastedVectorType;
- if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+ if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
unbroadcastedVectorType))) {
return failure();
}
- Value fill = rewriter.create<vector::SplatOp>(
- readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+ Location loc = readOp.getLoc();
+ Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
Value load = rewriter.create<vector::LoadOp>(
- readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
- readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(
- readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+ loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+ Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
// Insert a broadcasting op if required.
- if (!broadcastedDims.empty()) {
- res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
- readOp.getVectorType(), res);
+ if (requiresBroadcasting) {
+ res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+ res);
}
rewriter.replaceOp(readOp, res);
@@ -136,12 +135,11 @@ void mlir::populateVectorToAMDGPUConversionPatterns(
patterns.add<TransferReadLowering>(patterns.getContext());
}
-struct ConvertVectorToAMDGPUPass
+struct ConvertVectorToAMDGPUPass final
: public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToAMDGPUConversionPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
- return signalPassFailure();
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
index 30d9814cc0621..d0a79045c86da 100644
--- a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
+++ b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-amdgpu --split-input-file | FileCheck %s
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
@@ -9,9 +9,10 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<4xf32>
// -----
@@ -43,9 +44,10 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
// CHECK: return %[[BROADCAST]] : vector<4xf32>
@@ -62,7 +64,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
return %res : vector<1xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<1xf32>
>From f7ca23b6d3dd559bac386dcb65cde1dfe7ee1d71 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Wed, 19 Mar 2025 13:29:31 +0000
Subject: [PATCH 3/3] Address review feedback
---
mlir/include/mlir/Conversion/Passes.td | 2 +-
mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp | 6 +++++-
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1845d0235183e..d8f00c7109a3d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1340,7 +1340,7 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
let summary = "Lower the operations from the vector dialect into the AMDGPU "
"dialect";
- let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+ let dependentDialects = ["arith::ArithDialect", "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
index 569e8f8bb4e3a..b923546d7aaa5 100644
--- a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -97,6 +97,8 @@ static LogicalResult transferPreconditions(
return success();
}
+namespace {
+
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
@@ -130,13 +132,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
};
+} // namespace
+
void mlir::populateVectorToAMDGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadLowering>(patterns.getContext());
}
struct ConvertVectorToAMDGPUPass final
- : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+ : impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToAMDGPUConversionPatterns(patterns);
More information about the Mlir-commits
mailing list