[Mlir-commits] [mlir] [mlir][amdgpu] Align Chipset with TargetParser (PR #107720)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 7 14:25:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-gpu
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Update the `Chipset` struct to follow the `IsaVersion` definition from LLVM's `TargetParser`. This is a follow-up to https://github.com/llvm/llvm-project/pull/106169#discussion_r1733955012.
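* Add the stepping version. Note: this may break downstream code that compares against the minor version directly, because the former two-hex-digit minor version (e.g., `0x42` for gfx942) is now split into a one-digit minor version and a one-digit stepping version.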
* Use comparisons against full `Chipset` versions where possible.
Note that we can't use the code in `TargetParser` directly because the chipset utility lives outside of `mlir/Target`, which is what re-exports LLVM's target library.
---
Full diff: https://github.com/llvm/llvm-project/pull/107720.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h (+24-19)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+20-19)
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+1-1)
- (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+4-5)
- (modified) mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp (+6-2)
- (modified) mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp (+20-11)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 0e2708b1efae03..8ee7c03730e5d2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -9,39 +9,46 @@
#define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
#include "mlir/Support/LLVM.h"
-#include <utility>
+#include <tuple>
namespace mlir::amdgpu {
/// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
-/// Note that the leading digits form a decimal number, while the last two
-/// digits for a hexadecimal number. For example:
+/// The leading digits form the decimal major version; the last two characters
+/// are single hex digits: the minor version, then the stepping. For example:
-/// gfx942 --> major = 9, minor = 0x42
-/// gfx90a --> major = 9, minor = 0xa
-/// gfx1103 --> major = 10, minor = 0x3
+/// gfx942 --> major = 9, minor = 0x4, stepping = 0x2
+/// gfx90a --> major = 9, minor = 0x0, stepping = 0xa
+/// gfx1103 --> major = 11, minor = 0x0, stepping = 0x3
struct Chipset {
- Chipset() = default;
- Chipset(unsigned majorVersion, unsigned minorVersion)
- : majorVersion(majorVersion), minorVersion(minorVersion){};
+ unsigned majorVersion = 0; // The major version (decimal).
+ unsigned minorVersion = 0; // The minor version (hexadecimal).
+ unsigned steppingVersion = 0; // The stepping version (hexadecimal).
+
+ constexpr Chipset() = default;
+ constexpr Chipset(unsigned major, unsigned minor, unsigned stepping)
+ : majorVersion(major), minorVersion(minor), steppingVersion(stepping) {}
/// Parses the chipset version string and returns the chipset on success, and
/// failure otherwise.
static FailureOr<Chipset> parse(StringRef name);
- friend bool operator==(const Chipset &lhs, const Chipset &rhs) {
- return lhs.majorVersion == rhs.majorVersion &&
- lhs.minorVersion == rhs.minorVersion;
- }
- friend bool operator!=(const Chipset &lhs, const Chipset &rhs) {
- return !(lhs == rhs);
- }
- friend bool operator<(const Chipset &lhs, const Chipset &rhs) {
- return std::make_pair(lhs.majorVersion, lhs.minorVersion) <
- std::make_pair(rhs.majorVersion, rhs.minorVersion);
+ /// Returns (major, minor, stepping); comparisons are lexicographic on it.
+ constexpr std::tuple<unsigned, unsigned, unsigned> asTuple() const {
+ return {majorVersion, minorVersion, steppingVersion};
}
- unsigned majorVersion = 0; // The major version (decimal).
- unsigned minorVersion = 0; // The minor version (hexadecimal).
+#define DEFINE_COMP_OPERATOR(OPERATOR) \
+ friend constexpr bool operator OPERATOR(const Chipset &lhs, \
+ const Chipset &rhs) { \
+ return lhs.asTuple() OPERATOR rhs.asTuple(); \
+ }
+ DEFINE_COMP_OPERATOR(==)
+ DEFINE_COMP_OPERATOR(!=)
+ DEFINE_COMP_OPERATOR(<)
+ DEFINE_COMP_OPERATOR(<=)
+ DEFINE_COMP_OPERATOR(>)
+ DEFINE_COMP_OPERATOR(>=)
+#undef DEFINE_COMP_OPERATOR
};
} // namespace mlir::amdgpu
} // namespace mlir::amdgpu
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e407f1ca528d8..96b433294d258a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
}
namespace {
+// Define commonly used chipset versions for convenience.
+constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
@@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LogicalResult
matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- bool requiresInlineAsm =
- chipset.majorVersion < 9 ||
- (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
- (chipset.majorVersion == 11);
+ bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
if (requiresInlineAsm) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
@@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
destElem = destType.getElementType();
if (sourceElem.isF32() && destElem.isF32()) {
- if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
+ if (mfma.getReducePrecision() && chipset >= kGfx940) {
if (m == 32 && n == 32 && k == 4 && b == 1)
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
if (m == 16 && n == 16 && k == 8 && b == 1)
@@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
}
- if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
+ if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_i32_32x32x8i8::getOperationName();
if (m == 16 && n == 16 && k == 16 && b == 1)
return ROCDL::mfma_i32_16x16x16i8::getOperationName();
- if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
+ if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
- if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
+ if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
}
- if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
+ if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
if (m == 16 && n == 16 && k == 4 && b == 1)
return ROCDL::mfma_f64_16x16x4f64::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 4)
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
@@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
@@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
if (outVecType.getElementType().isBF16())
intrinsicOutType = outVecType.clone(rewriter.getI16Type());
- if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
+ if (chipset.majorVersion != 9 || chipset < kGfx908)
return op->emitOpError("MFMA only supported on gfx908+");
uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
- if (chipset.minorVersion < 0x40)
- return op.emitOpError("negation unsupported on older than gfx840");
+ if (chipset < kGfx940)
+ return op.emitOpError("negation unsupported on older than gfx940");
getBlgpField |=
op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
}
@@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+ if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+ if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+ if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..6b27ec9947cb0b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -384,7 +384,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+ maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index f89e2537897e80..21042aff529c9d 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -9,12 +9,12 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
@@ -146,13 +146,12 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// gfx10 has no atomic adds.
- if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
- (chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
+ if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
target.addIllegalOp<RawBufferAtomicFaddOp>();
}
-// gfx9 has no to a very limited support for floating-point min and max.
+// gfx9 has no or very limited support for floating-point min and max.
if (chipset.majorVersion == 9) {
- if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) {
+ if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
// gfx90a supports f64 max (and min, but we don't have a min wrapper right
// now) but all other types need to be emulated.
target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
@@ -162,7 +161,7 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
} else {
target.addIllegalOp<RawBufferAtomicFmaxOp>();
}
- if (chipset.minorVersion == 0x41) {
+ if (chipset == Chipset(9, 4, 1)) {
// gfx941 requires non-CAS atomics to be implemented with CAS loops.
// The workaround here mirrors HIP and OpenMP.
target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
diff --git a/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp b/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
index fd15879d7b7ea0..293738982e060c 100644
--- a/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
@@ -19,14 +19,18 @@ FailureOr<Chipset> Chipset::parse(StringRef name) {
unsigned major = 0;
unsigned minor = 0;
+ unsigned stepping = 0;
StringRef majorRef = name.drop_back(2);
- StringRef minorRef = name.take_back(2);
+ StringRef minorRef = name.take_back(2).drop_back(1);
+ StringRef steppingRef = name.take_back(1);
if (majorRef.getAsInteger(10, major))
return failure();
if (minorRef.getAsInteger(16, minor))
return failure();
- return Chipset(major, minor);
+ if (steppingRef.getAsInteger(16, stepping))
+ return failure();
+ return Chipset(major, minor, stepping);
}
} // namespace mlir::amdgpu
} // namespace mlir::amdgpu
diff --git a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
index b08b6681235d3b..976ff2e7382edf 100644
--- a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
+++ b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
@@ -16,17 +16,20 @@ TEST(ChipsetTest, Parsing) {
FailureOr<Chipset> chipset = Chipset::parse("gfx90a");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 9u);
- EXPECT_EQ(chipset->minorVersion, 0x0au);
+ EXPECT_EQ(chipset->minorVersion, 0u);
+ EXPECT_EQ(chipset->steppingVersion, 0xau);
chipset = Chipset::parse("gfx940");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 9u);
- EXPECT_EQ(chipset->minorVersion, 0x40u);
+ EXPECT_EQ(chipset->minorVersion, 4u);
+ EXPECT_EQ(chipset->steppingVersion, 0u);
chipset = Chipset::parse("gfx1103");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 11u);
- EXPECT_EQ(chipset->minorVersion, 0x03u);
+ EXPECT_EQ(chipset->minorVersion, 0u);
+ EXPECT_EQ(chipset->steppingVersion, 3u);
}
TEST(ChipsetTest, ParsingInvalid) {
@@ -43,14 +46,20 @@ TEST(ChipsetTest, ParsingInvalid) {
}
TEST(ChipsetTest, Comparison) {
- EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40));
- EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42));
- EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00));
-
- EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00));
- EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42));
- EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42));
- EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40));
+ EXPECT_EQ(Chipset(9, 4, 0), Chipset(9, 4, 0));
+ EXPECT_NE(Chipset(9, 4, 0), Chipset(9, 4, 2));
+ EXPECT_NE(Chipset(9, 0, 0), Chipset(10, 0, 0));
+
+ EXPECT_LT(Chipset(9, 0, 0), Chipset(10, 0, 0));
+ EXPECT_LT(Chipset(9, 0, 0xa), Chipset(9, 4, 2));
+ EXPECT_LE(Chipset(9, 4, 1), Chipset(9, 4, 1));
+ EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 2));
+ EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 0));
+
+ EXPECT_GT(Chipset(9, 0, 0xa), Chipset(9, 0, 8));
+ EXPECT_GE(Chipset(9, 0, 0xa), Chipset(9, 0, 0xa));
+ EXPECT_FALSE(Chipset(9, 4, 1) >= Chipset(9, 4, 2));
+ EXPECT_FALSE(Chipset(9, 0, 0xa) >= Chipset(9, 4, 0));
}
} // namespace
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/107720
More information about the Mlir-commits mailing list