[Mlir-commits] [mlir] 763bc92 - [mlir][amdgpu] Align Chipset with TargetParser (#107720)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 08:12:31 PDT 2024


Author: Jakub Kuderski
Date: 2024-09-09T11:12:26-04:00
New Revision: 763bc9249cf0b7da421182e24716d9a569fb5184

URL: https://github.com/llvm/llvm-project/commit/763bc9249cf0b7da421182e24716d9a569fb5184
DIFF: https://github.com/llvm/llvm-project/commit/763bc9249cf0b7da421182e24716d9a569fb5184.diff

LOG: [mlir][amdgpu] Align Chipset with TargetParser (#107720)

Update the Chipset struct to follow the `IsaVersion` definition from
llvm's `TargetParser`. This is a follow up to
https://github.com/llvm/llvm-project/pull/106169#discussion_r1733955012.

* Add the stepping version. Note: This may break downstream code that
compares against the minor version directly.
* Use comparisons with full Chipset version where possible.

Note that we can't use the code in `TargetParser` directly because the
chipset utility lives outside of `mlir/Target`, which is the part of MLIR
that re-exports LLVM's target library.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
    mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
    mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
    mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 0e2708b1efae03..a5dab1ab896302 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -9,39 +9,44 @@
 #define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
 
 #include "mlir/Support/LLVM.h"
-#include <utility>
+#include <tuple>
 
 namespace mlir::amdgpu {
 
 /// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
-/// Note that the leading digits form a decimal number, while the last two
-/// digits for a hexadecimal number. For example:
-///   gfx942  --> major = 9, minor = 0x42
-///   gfx90a  --> major = 9, minor = 0xa
-///   gfx1103 --> major = 10, minor = 0x3
+/// The leading digits form the decimal major version; the last two digits are
+/// the hexadecimal minor and stepping versions, one digit each. For example:
+///   gfx942  --> major = 9, minor = 0x4, stepping = 0x2
+///   gfx90a  --> major = 9, minor = 0x0, stepping = 0xa
+///   gfx1103 --> major = 11, minor = 0x0, stepping = 0x3
 struct Chipset {
-  Chipset() = default;
-  Chipset(unsigned majorVersion, unsigned minorVersion)
-      : majorVersion(majorVersion), minorVersion(minorVersion){};
+  unsigned majorVersion = 0;    // The major version (decimal).
+  unsigned minorVersion = 0;    // The minor version (hexadecimal).
+  unsigned steppingVersion = 0; // The stepping version (hexadecimal).
+
+  constexpr Chipset() = default;
+  constexpr Chipset(unsigned major, unsigned minor, unsigned stepping)
+      : majorVersion(major), minorVersion(minor), steppingVersion(stepping) {}
 
   /// Parses the chipset version string and returns the chipset on success, and
   /// failure otherwise.
   static FailureOr<Chipset> parse(StringRef name);
 
-  friend bool operator==(const Chipset &lhs, const Chipset &rhs) {
-    return lhs.majorVersion == rhs.majorVersion &&
-           lhs.minorVersion == rhs.minorVersion;
-  }
-  friend bool operator!=(const Chipset &lhs, const Chipset &rhs) {
-    return !(lhs == rhs);
-  }
-  friend bool operator<(const Chipset &lhs, const Chipset &rhs) {
-    return std::make_pair(lhs.majorVersion, lhs.minorVersion) <
-           std::make_pair(rhs.majorVersion, rhs.minorVersion);
+  std::tuple<unsigned, unsigned, unsigned> asTuple() const {
+    return {majorVersion, minorVersion, steppingVersion};
   }
 
-  unsigned majorVersion = 0; // The major version (decimal).
-  unsigned minorVersion = 0; // The minor version (hexadecimal).
+#define DEFINE_COMP_OPERATOR(OPERATOR)                                         \
+  friend bool operator OPERATOR(const Chipset &lhs, const Chipset &rhs) {      \
+    return lhs.asTuple() OPERATOR rhs.asTuple();                               \
+  }
+  DEFINE_COMP_OPERATOR(==)
+  DEFINE_COMP_OPERATOR(!=)
+  DEFINE_COMP_OPERATOR(<)
+  DEFINE_COMP_OPERATOR(<=)
+  DEFINE_COMP_OPERATOR(>)
+  DEFINE_COMP_OPERATOR(>=)
+#undef DEFINE_COMP_OPERATOR
 };
 
 } // namespace mlir::amdgpu

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e407f1ca528d8..96b433294d258a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
 }
 
 namespace {
+// Define commonly used chipset versions for convenience.
+constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
 /// Define lowering patterns for raw buffer ops
 template <typename GpuOp, typename Intrinsic>
 struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
@@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
   LogicalResult
   matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    bool requiresInlineAsm =
-        chipset.majorVersion < 9 ||
-        (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
-        (chipset.majorVersion == 11);
+    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
 
     if (requiresInlineAsm) {
       auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
@@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
     destElem = destType.getElementType();
 
   if (sourceElem.isF32() && destElem.isF32()) {
-    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
+    if (mfma.getReducePrecision() && chipset >= kGfx940) {
       if (m == 32 && n == 32 && k == 4 && b == 1)
         return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
       if (m == 16 && n == 16 && k == 8 && b == 1)
@@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f32_16x16x16f16::getOperationName();
   }
 
-  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
+  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_i32_32x32x8i8::getOperationName();
     if (m == 16 && n == 16 && k == 16 && b == 1)
       return ROCDL::mfma_i32_16x16x16i8::getOperationName();
-    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
+    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
       return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
-    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
+    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
       return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
   }
 
-  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
+  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
     if (m == 16 && n == 16 && k == 4 && b == 1)
       return ROCDL::mfma_f64_16x16x4f64::getOperationName();
     if (m == 4 && n == 4 && k == 4 && b == 4)
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
@@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
@@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
       if (outVecType.getElementType().isBF16())
         intrinsicOutType = outVecType.clone(rewriter.getI16Type());
 
-    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
+    if (chipset.majorVersion != 9 || chipset < kGfx908)
       return op->emitOpError("MFMA only supported on gfx908+");
     uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
     if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
-      if (chipset.minorVersion < 0x40)
-        return op.emitOpError("negation unsupported on older than gfx840");
+      if (chipset < kGfx940)
+        return op.emitOpError("negation unsupported on older than gfx940");
       getBlgpField |=
           op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
     }
@@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
     ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+  if (chipset.majorVersion != 9 || chipset < kGfx940)
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
     PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+  if (chipset.majorVersion != 9 || chipset < kGfx940)
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
     PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+  if (chipset.majorVersion != 9 || chipset < kGfx940)
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");

diff  --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..6b27ec9947cb0b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -384,7 +384,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
   }
 
   bool convertFP8Arithmetic =
-      (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+      maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
       *maybeChipset);

diff  --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index f89e2537897e80..21042aff529c9d 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -9,12 +9,12 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
@@ -146,13 +146,12 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
 void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
     ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
   // gfx10 has no atomic adds.
-  if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
-      (chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
+  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
     target.addIllegalOp<RawBufferAtomicFaddOp>();
   }
-  // gfx9 has no to a very limited support for floating-point min and max.
+  // gfx9 has no or very limited support for floating-point min and max.
   if (chipset.majorVersion == 9) {
-    if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) {
+    if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
       // gfx90a supports f64 max (and min, but we don't have a min wrapper right
       // now) but all other types need to be emulated.
       target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
@@ -162,7 +161,7 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
     } else {
       target.addIllegalOp<RawBufferAtomicFmaxOp>();
     }
-    if (chipset.minorVersion == 0x41) {
+    if (chipset == Chipset(9, 4, 1)) {
       // gfx941 requires non-CAS atomics to be implemented with CAS loops.
       // The workaround here mirrors HIP and OpenMP.
       target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,

diff  --git a/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp b/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
index fd15879d7b7ea0..293738982e060c 100644
--- a/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
@@ -19,14 +19,18 @@ FailureOr<Chipset> Chipset::parse(StringRef name) {
 
   unsigned major = 0;
   unsigned minor = 0;
+  unsigned stepping = 0;
 
   StringRef majorRef = name.drop_back(2);
-  StringRef minorRef = name.take_back(2);
+  StringRef minorRef = name.take_back(2).drop_back(1);
+  StringRef steppingRef = name.take_back(1);
   if (majorRef.getAsInteger(10, major))
     return failure();
   if (minorRef.getAsInteger(16, minor))
     return failure();
-  return Chipset(major, minor);
+  if (steppingRef.getAsInteger(16, stepping))
+    return failure();
+  return Chipset(major, minor, stepping);
 }
 
 } // namespace mlir::amdgpu

diff  --git a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
index b08b6681235d3b..976ff2e7382edf 100644
--- a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
+++ b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
@@ -16,17 +16,20 @@ TEST(ChipsetTest, Parsing) {
   FailureOr<Chipset> chipset = Chipset::parse("gfx90a");
   ASSERT_TRUE(succeeded(chipset));
   EXPECT_EQ(chipset->majorVersion, 9u);
-  EXPECT_EQ(chipset->minorVersion, 0x0au);
+  EXPECT_EQ(chipset->minorVersion, 0u);
+  EXPECT_EQ(chipset->steppingVersion, 0xau);
 
   chipset = Chipset::parse("gfx940");
   ASSERT_TRUE(succeeded(chipset));
   EXPECT_EQ(chipset->majorVersion, 9u);
-  EXPECT_EQ(chipset->minorVersion, 0x40u);
+  EXPECT_EQ(chipset->minorVersion, 4u);
+  EXPECT_EQ(chipset->steppingVersion, 0u);
 
   chipset = Chipset::parse("gfx1103");
   ASSERT_TRUE(succeeded(chipset));
   EXPECT_EQ(chipset->majorVersion, 11u);
-  EXPECT_EQ(chipset->minorVersion, 0x03u);
+  EXPECT_EQ(chipset->minorVersion, 0u);
+  EXPECT_EQ(chipset->steppingVersion, 3u);
 }
 
 TEST(ChipsetTest, ParsingInvalid) {
@@ -43,14 +46,20 @@ TEST(ChipsetTest, ParsingInvalid) {
 }
 
 TEST(ChipsetTest, Comparison) {
-  EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40));
-  EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42));
-  EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00));
-
-  EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00));
-  EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42));
-  EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42));
-  EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40));
+  EXPECT_EQ(Chipset(9, 4, 0), Chipset(9, 4, 0));
+  EXPECT_NE(Chipset(9, 4, 0), Chipset(9, 4, 2));
+  EXPECT_NE(Chipset(9, 0, 0), Chipset(10, 0, 0));
+
+  EXPECT_LT(Chipset(9, 0, 0), Chipset(10, 0, 0));
+  EXPECT_LT(Chipset(9, 0, 0), Chipset(9, 4, 2));
+  EXPECT_LE(Chipset(9, 4, 1), Chipset(9, 4, 1));
+  EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 2));
+  EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 0));
+
+  EXPECT_GT(Chipset(9, 0, 0xa), Chipset(9, 0, 8));
+  EXPECT_GE(Chipset(9, 0, 0xa), Chipset(9, 0, 0xa));
+  EXPECT_FALSE(Chipset(9, 4, 1) >= Chipset(9, 4, 2));
+  EXPECT_FALSE(Chipset(9, 0, 0xa) >= Chipset(9, 4, 0));
 }
 
 } // namespace


        


More information about the Mlir-commits mailing list