[Mlir-commits] [mlir] cab44c5 - [mlir][AMDGPU] Add --chipset option to AMDGPUToROCDL
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Jul 7 07:58:21 PDT 2022
Author: Krzysztof Drewniak
Date: 2022-07-07T14:58:13Z
New Revision: cab44c515c63db876a7f626fdc36be4a9157ebad
URL: https://github.com/llvm/llvm-project/commit/cab44c515c63db876a7f626fdc36be4a9157ebad
DIFF: https://github.com/llvm/llvm-project/commit/cab44c515c63db876a7f626fdc36be4a9157ebad.diff
LOG: [mlir][AMDGPU] Add --chipset option to AMDGPUToROCDL
Because the buffer descriptor structure (the V#) has no backwards-compatibility
guarantees, and since said guarantees have been violated in practice
(see https://github.com/llvm/llvm-project/issues/56323 ), and since
the `targetIsRDNA` attribute isn't something that higher-level clients can set
in general, make the lowering of the amdgpu dialect to rocdl take a --chipset
option.
Note that this option is a string because adding a parser for the Chipset
struct to llvm::cl wasn't working out.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D129228
Added:
mlir/include/mlir/Conversion/AMDGPUToROCDL/Chipset.h
mlir/lib/Conversion/AMDGPUToROCDL/Chipset.cpp
Modified:
mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/AMDGPUToROCDL/CMakeLists.txt
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index 1312fa720b49e..7906abafd7189 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -8,6 +8,7 @@
#ifndef MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_
#define MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_
+#include "mlir/Conversion/AMDGPUToROCDL/Chipset.h"
#include <memory>
namespace mlir {
@@ -17,7 +18,8 @@ class RewritePatternSet;
class Pass;
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ amdgpu::Chipset chipset);
std::unique_ptr<Pass> createConvertAMDGPUToROCDLPass();
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/Chipset.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/Chipset.h
new file mode 100644
index 0000000000000..82808821b6e95
--- /dev/null
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/Chipset.h
@@ -0,0 +1,27 @@
+//===- Chipset.h - AMDGPU Chipset version struct ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_AMDGPUTOROCDL_CHIPSET_H_
+#define MLIR_CONVERSION_AMDGPUTOROCDL_CHIPSET_H_
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace amdgpu {
+struct Chipset {
+ Chipset() = default;
+ Chipset(unsigned majorVersion, unsigned minorVersion)
+ : majorVersion(majorVersion), minorVersion(minorVersion){};
+ static FailureOr<Chipset> parse(StringRef name);
+
+ unsigned majorVersion = 0;
+ unsigned minorVersion = 0;
+};
+} // end namespace amdgpu
+} // end namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 33b38cecf0fdd..2879cd5b6553c 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -39,6 +39,7 @@ void configureGpuToROCDLConversionLegality(ConversionTarget &target);
/// is configurable.
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createLowerGpuOpsToROCDLOpsPass(
+ const std::string &chipset = "gfx900",
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ebe9dd7645f7b..3098e7a73be44 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -87,6 +87,9 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
"LLVM::LLVMDialect",
"ROCDL::ROCDLDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"gfx000\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
@@ -364,6 +367,9 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
let dependentDialects = ["ROCDL::ROCDLDialect"];
let options = [
+ Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"gfx000\"",
+ "Chipset that these operations will run on">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index 746970ccc8d54..f1e1bb3df9502 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -37,7 +37,6 @@ def AMDGPU_RawBufferLoadOp :
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
AttrSizedOperandSegments]>,
Arguments<(ins Arg<AnyMemRef, "buffer to load from", [MemRead]>:$memref,
- BoolAttr:$targetIsRDNA,
Variadic<I32>:$indices,
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
OptionalAttr<I32Attr>:$indexOffset,
@@ -71,11 +70,9 @@ def AMDGPU_RawBufferLoadOp :
as max_d (size(d) * stride(d)) * sizeof(elementType(memref))
- The offset enable bit is 1, the index enable bit is 0.
- The thread ID addition bit is off
- - If `boundsCheck` is false and the target is RDNA, OOB_SELECT is set to 2
- to disable bounds checks, otherwise it is 0
+ - If `boundsCheck` is false and the target chipset is RDNA, OOB_SELECT is set
+ to 2 to disable bounds checks, otherwise it is 0
- The cache coherency bits are off
- - `targetIsRDNA` controls the setting of some reserved values that differ
- between RDNA and CDNA cores
}];
let assemblyFormat = [{
attr-dict $memref `[` $indices `]`
@@ -94,7 +91,6 @@ def AMDGPU_RawBufferStoreOp :
VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value,
Arg<AnyMemRef, "buffer to store to", [MemWrite]>:$memref,
- BoolAttr:$targetIsRDNA,
Variadic<I32>:$indices,
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
OptionalAttr<I32Attr>:$indexOffset,
@@ -132,7 +128,6 @@ def AMDGPU_RawBufferAtomicFaddOp :
AttrSizedOperandSegments]>,
Arguments<(ins F32:$value,
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
- BoolAttr:$targetIsRDNA,
Variadic<I32>:$indices,
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
OptionalAttr<I32Attr>:$indexOffset,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 56d4967b608c6..f7e99e3731e89 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
using namespace mlir;
+using namespace mlir::amdgpu;
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
@@ -26,8 +27,10 @@ namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
- using ConvertOpToLLVMPattern<GpuOp>::ConvertOpToLLVMPattern;
+ RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
+ Chipset chipset;
static constexpr uint32_t maxVectorOpWidth = 128;
LogicalResult
@@ -38,6 +41,9 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Value unconvertedMemref = gpuOp.getMemref();
MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
+ if (chipset.majorVersion < 9)
+ return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
+
Value storeData = adaptor.getODSOperands(0)[0];
if (storeData == memref) // no write component to this op
storeData = Value();
@@ -57,7 +63,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// If we want to load a vector<NxT> with total size <= 32
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
- // and the
+ // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
+ // 32) x i32 and bitcast.
Type llvmBufferValType = llvmWantedDataType;
if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
@@ -163,7 +170,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// swizzles) RDNA only
// bits 30-31: Type (must be 0)
uint32_t word3 = (7 << 12) | (4 << 15);
- if (adaptor.getTargetIsRDNA()) {
+ if (chipset.majorVersion == 10) {
word3 |= (1 << 24);
uint32_t oob = adaptor.getBoundsCheck() ? 1 : 2;
word3 |= (oob << 28);
@@ -239,9 +246,16 @@ struct ConvertAMDGPUToROCDLPass
ConvertAMDGPUToROCDLPass() = default;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- LLVMTypeConverter converter(&getContext());
- populateAMDGPUToROCDLConversionPatterns(converter, patterns);
+ MLIRContext *ctx = &getContext();
+ FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
+ if (failed(maybeChipset)) {
+ emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+ return signalPassFailure();
+ }
+
+ RewritePatternSet patterns(ctx);
+ LLVMTypeConverter converter(ctx);
+ populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
@@ -252,13 +266,14 @@ struct ConvertAMDGPUToROCDLPass
};
} // namespace
-void mlir::populateAMDGPUToROCDLConversionPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ Chipset chipset) {
patterns.add<
- RawBufferOpLowering<amdgpu::RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
- RawBufferOpLowering<amdgpu::RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
- RawBufferOpLowering<amdgpu::RawBufferAtomicFaddOp,
- ROCDL::RawBufferAtomicFAddOp>>(converter);
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
+ RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
+ converter, chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/AMDGPUToROCDL/CMakeLists.txt
index db1d19d635267..f574b4c55dd83 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRAMDGPUToROCDL
AMDGPUToROCDL.cpp
+ Chipset.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AMDGPUToROCDL
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/Chipset.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/Chipset.cpp
new file mode 100644
index 0000000000000..99087e1c24209
--- /dev/null
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/Chipset.cpp
@@ -0,0 +1,28 @@
+//===- Chipset.cpp - AMDGPU Chipset version struct parsing -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AMDGPUToROCDL/Chipset.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+FailureOr<Chipset> Chipset::parse(StringRef name) {
+ if (!name.startswith("gfx"))
+ return failure();
+ unsigned major = 0;
+ unsigned minor = 0;
+ StringRef majorRef = name.drop_front(3).drop_back(2);
+ StringRef minorRef = name.take_back(2);
+ if (majorRef.getAsInteger(10, major))
+ return failure();
+ if (minorRef.getAsInteger(16, minor))
+ return failure();
+ return Chipset(major, minor);
+}
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index a149dda2be39b..1704f6ede4ef5 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -55,37 +55,46 @@ namespace {
struct LowerGpuOpsToROCDLOpsPass
: public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
LowerGpuOpsToROCDLOpsPass() = default;
- LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth, gpu::amd::Runtime runtime) {
+ LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
+ gpu::amd::Runtime runtime) {
+ this->chipset = chipset;
this->indexBitwidth = indexBitwidth;
this->runtime = runtime;
}
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
+ MLIRContext *ctx = m.getContext();
// Request C wrapper emission.
for (auto func : m.getOps<func::FuncOp>()) {
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
- UnitAttr::get(&getContext()));
+ UnitAttr::get(ctx));
+ }
+
+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset)) {
+ emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+ return signalPassFailure();
}
/// Customize the bitwidth used for the device side index computations.
LowerToLLVMOptions options(
- m.getContext(),
- DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
+ ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
- LLVMTypeConverter converter(m.getContext(), options);
+ LLVMTypeConverter converter(ctx, options);
- RewritePatternSet patterns(m.getContext());
- RewritePatternSet llvmPatterns(m.getContext());
+ RewritePatternSet patterns(ctx);
+ RewritePatternSet llvmPatterns(ctx);
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
llvmPatterns);
- populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns);
+ populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
+ *maybeChipset);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToROCDLConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
@@ -180,7 +189,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth,
+mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
+ unsigned indexBitwidth,
gpu::amd::Runtime runtime) {
- return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth, runtime);
+ return std::make_unique<LowerGpuOpsToROCDLOpsPass>(chipset, indexBitwidth,
+ runtime);
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 91043a660d635..f3f4a29aa625f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -1,34 +1,27 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefix=RDNA
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
// CHECK: llvm.insertelement{{.*}}%[[numRecords]]
// CHECK: %[[word3:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[word3:.*]] = llvm.mlir.constant(285372416 : i32)
// CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
// CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
// CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = false} %buf[%idx] : memref<64xi32>, i32 -> i32
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> i32
func.return %0 : i32
}
-// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_rdna
-func.func @gpu_gcn_raw_buffer_load_i32_rdna(%buf: memref<64xi32>, %idx: i32) -> i32 {
- // CHECK: %[[word3:.*]] = llvm.mlir.constant(285372416 : i32)
- // CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
- // CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
- // CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = true} %buf[%idx] : memref<64xi32>, i32 -> i32
- func.return %0 : i32
-}
-
-// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_rdna_oob_off
-func.func @gpu_gcn_raw_buffer_load_i32_rdna_oob_off(%buf: memref<64xi32>, %idx: i32) -> i32 {
- // CHECK: %[[word3:.*]] = llvm.mlir.constant(553807872 : i32)
- // CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
- // CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
- // CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = false, targetIsRDNA = true} %buf[%idx] : memref<64xi32>, i32 -> i32
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_oob_off
+func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32) -> i32 {
+ // CHECK: %[[word3:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[word3:.*]] = llvm.mlir.constant(553807872 : i32)
+ // RDNA: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
+ // RDNA: %[[ret:.*]] = rocdl.raw.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+ // RDNA: return %[[ret]]
+ %0 = amdgpu.raw_buffer_load {boundsCheck = false} %buf[%idx] : memref<64xi32>, i32 -> i32
func.return %0 : i32
}
@@ -36,7 +29,7 @@ func.func @gpu_gcn_raw_buffer_load_i32_rdna_oob_off(%buf: memref<64xi32>, %idx:
func.func @gpu_gcn_raw_buffer_load_2xi32(%buf: memref<64xi32>, %idx: i32) -> vector<2xi32> {
// CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi32>
// CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = false} %buf[%idx] : memref<64xi32>, i32 -> vector<2xi32>
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> vector<2xi32>
func.return %0 : vector<2xi32>
}
@@ -46,7 +39,7 @@ func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
// CHECK: llvm.insertelement{{.*}}%[[numRecords]]
// CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i8
// CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = false} %buf[%idx] : memref<64xi8>, i32 -> i8
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> i8
func.return %0 : i8
}
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
@@ -56,7 +49,7 @@ func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vecto
// CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i16
// CHECK: %[[ret:.*]] = llvm.bitcast %[[loaded]] : i16 to vector<2xi8>
// CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = false} %buf[%idx] : memref<64xi8>, i32 -> vector<2xi8>
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> vector<2xi8>
func.return %0 : vector<2xi8>
}
@@ -65,7 +58,7 @@ func.func @gpu_gcn_raw_buffer_load_16xi8(%buf: memref<64xi8>, %idx: i32) -> vect
// CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi32>
// CHECK: %[[ret:.*]] = llvm.bitcast %[[loaded]] : vector<4xi32> to vector<16xi8>
// CHECK: return %[[ret]]
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, targetIsRDNA = false} %buf[%idx] : memref<64xi8>, i32 -> vector<16xi8>
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> vector<16xi8>
func.return %0 : vector<16xi8>
}
@@ -77,7 +70,7 @@ func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx:
// CHECK: %[[word3:.*]] = llvm.mlir.constant(159744 : i32)
// CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
// CHECK: rocdl.raw.buffer.store %{{.*}} %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
- amdgpu.raw_buffer_store {boundsCheck = true, targetIsRDNA = false} %value -> %buf[%idx] : i32 -> memref<64xi32>, i32
+ amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : i32 -> memref<64xi32>, i32
func.return
}
@@ -85,7 +78,7 @@ func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx:
func.func @gpu_gcn_raw_buffer_store_2xi8(%value: vector<2xi8>, %buf: memref<64xi8>, %idx: i32) {
// CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<2xi8> to i16
// CHECK: rocdl.raw.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i16
- amdgpu.raw_buffer_store {boundsCheck = true, targetIsRDNA = false} %value -> %buf[%idx] : vector<2xi8> -> memref<64xi8>, i32
+ amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<2xi8> -> memref<64xi8>, i32
func.return
}
@@ -93,7 +86,7 @@ func.func @gpu_gcn_raw_buffer_store_2xi8(%value: vector<2xi8>, %buf: memref<64xi
func.func @gpu_gcn_raw_buffer_store_16xi8(%value: vector<16xi8>, %buf: memref<64xi8>, %idx: i32) {
// CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: rocdl.raw.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi32>
- amdgpu.raw_buffer_store {boundsCheck = true, targetIsRDNA = false} %value -> %buf[%idx] : vector<16xi8> -> memref<64xi8>, i32
+ amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<16xi8> -> memref<64xi8>, i32
func.return
}
@@ -105,6 +98,6 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>,
// CHECK: %[[word3:.*]] = llvm.mlir.constant(159744 : i32)
// CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
// CHECK: rocdl.raw.buffer.atomic.fadd %{{.*}} %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : f32
- amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, targetIsRDNA = false} %value -> %buf[%idx] : f32 -> memref<64xf32>, i32
+ amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : f32 -> memref<64xf32>, i32
func.return
}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index e4d6ebab6889f..daf6b7a0dae0b 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -6,56 +6,56 @@
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
- // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %src[%idx0] sgprOffset %offset : memref<128xf32>, i32 -> f32
+ // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %src[%idx0] sgprOffset %offset : memref<128xf32>, i32 -> f32
func.return %0 : f32
}
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_4
func.func @raw_buffer_load_f32_from_rank_4(%src : memref<128x64x32x16xf32>, %offset : i32, %idx0 : i32, %idx1 : i32, %idx2 : i32, %idx3 : i32) -> f32 {
- // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> f32
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %src[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> f32
+ // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> f32
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %src[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> f32
func.return %0 : f32
}
// CHECK-LABEL: func @raw_buffer_load_4xf32_from_rank_4
func.func @raw_buffer_load_4xf32_from_rank_4(%src : memref<128x64x32x16xf32>, %offset : i32, %idx0 : i32, %idx1 : i32, %idx2 : i32, %idx3 : i32) -> vector<4xf32> {
- // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> vector<4xf32>
- %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %src[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> vector<4xf32>
+ // CHECK: amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> vector<4xf32>
+ %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 1 : i32} %src[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : memref<128x64x32x16xf32>, i32, i32, i32, i32 -> vector<4xf32>
func.return %0 : vector<4xf32>
}
// CHECK-LABEL: func @raw_buffer_store_f32_to_rank_1
func.func @raw_buffer_store_f32_to_rank_1(%value : f32, %dst : memref<128xf32>, %offset : i32, %idx0 : i32) {
- // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}} -> %{{.*}}[{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128xf32>, i32
- amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0] sgprOffset %offset : f32 -> memref<128xf32>, i32
+ // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %{{.*}} -> %{{.*}}[{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128xf32>, i32
+ amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %value -> %dst[%idx0] sgprOffset %offset : f32 -> memref<128xf32>, i32
func.return
}
// CHECK-LABEL: func @raw_buffer_store_f32_to_rank_4
func.func @raw_buffer_store_f32_to_rank_4(%value : f32, %dst : memref<128x64x32x16xf32>, %offset : i32, %idx0 : i32, %idx1 : i32, %idx2 : i32, %idx3 : i32) {
- // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
- amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
func.return
}
// CHECK-LABEL: func @raw_buffer_store_4xf32_to_rank_4
func.func @raw_buffer_store_4xf32_to_rank_4(%value : vector<4xf32>, %dst : memref<128x64x32x16xf32>, %offset : i32, %idx0 : i32, %idx1 : i32, %idx2 : i32, %idx3 : i32) {
- // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : vector<4xf32> -> memref<128x64x32x16xf32>, i32, i32, i32, i32
- amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : vector<4xf32> -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ // CHECK: amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : vector<4xf32> -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ amdgpu.raw_buffer_store {boundsCheck = true, indexOffset = 1 : i32} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : vector<4xf32> -> memref<128x64x32x16xf32>, i32, i32, i32, i32
func.return
}
// CHECK-LABEL: func @raw_buffer_atomic_fadd_f32_to_rank_1
func.func @raw_buffer_atomic_fadd_f32_to_rank_1(%value : f32, %dst : memref<128xf32>, %offset : i32, %idx0 : i32) {
- // CHECK: amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}} -> %{{.*}}[{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128xf32>, i32
- amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0] sgprOffset %offset : f32 -> memref<128xf32>, i32
+ // CHECK: amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32} %{{.*}} -> %{{.*}}[{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128xf32>, i32
+ amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32} %value -> %dst[%idx0] sgprOffset %offset : f32 -> memref<128xf32>, i32
func.return
}
// CHECK-LABEL: func @raw_buffer_atomic_fadd_f32_to_rank_4
func.func @raw_buffer_atomic_fadd_f32_to_rank_4(%value : f32, %dst : memref<128x64x32x16xf32>, %offset : i32, %idx0 : i32, %idx1 : i32, %idx2 : i32, %idx3 : i32) {
- // CHECK: amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
- amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ // CHECK: amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32} %{{.*}} -> %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] sgprOffset %{{.*}} : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
+ amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, indexOffset = 1 : i32} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32
func.return
}
More information about the Mlir-commits
mailing list