[Mlir-commits] [mlir] [MLIR][AMDGPU] Add support for fp8 ops on gfx12 (PR #106388)

Giuseppe Rossini llvmlistbot at llvm.org
Mon Sep 2 12:03:52 PDT 2024


https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/106388

>From d5c41ed566118bdc174f03bb816a199a056a1048 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 28 Aug 2024 14:10:55 +0100
Subject: [PATCH 1/3] [MLIR][AMDGPU] Add support for fp8 ops on gfx12

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  7 +++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 37 +++++++++++--------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  4 +-
 .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir  |  9 +++++
 mlir/test/Target/LLVMIR/rocdl.mlir            | 10 +++++
 6 files changed, 50 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..35789984c92212 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..bbb6e666d82956 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
-// Available on RDNA3
+// Available from gfx11
 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
 def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
 def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
 def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..45c5070333b527 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  bool isUnsigned, Value llvmInput,
+                                 Value mlirInput,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
 
+  auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
+  if (mlirInputType.getElementType().isInteger(8)) {
+    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+    bool localIsUnsigned = isUnsigned;
+    if (elemType.isUnsignedInteger(8)) {
+      localIsUnsigned = true;
+    } else if (elemType.isSignedInteger(8)) {
+      localIsUnsigned = false;
+    }
+    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+    operands.push_back(sign);
+  }
+
   int64_t numBytes = vectorType.getNumElements();
   Type i32 = rewriter.getI32Type();
   VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
   auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
       loc, llvmVectorType32bits, llvmInput);
-
-  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
-  bool localIsUnsigned = isUnsigned;
-  if (elemType.isUnsignedInteger(8)) {
-    localIsUnsigned = true;
-  } else if (elemType.isSignedInteger(8)) {
-    localIsUnsigned = false;
-  }
-  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-  operands.push_back(sign);
   operands.push_back(result);
 }
 
@@ -601,6 +604,10 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+  } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+  } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
   }
   return std::nullopt;
 }
@@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     Location loc = op.getLoc();
     Type outType = typeConverter->convertType(op.getDestD().getType());
 
-    if (chipset.majorVersion != 11)
-      return op->emitOpError("WMMA only supported on gfx11");
+    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+      return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
@@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), operands);
+                         adaptor.getSourceA(), op.getSourceA(), operands);
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), operands);
+                         adaptor.getSourceB(), op.getSourceB(), operands);
     wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                           op.getSubwordOffset(), op.getClamp(), operands);
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..a8d6ccdc1a471e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() {
 
   bool isDestFloat =
       (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
-  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+  bool isSrcFloat =
+      (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
+       sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
 
   if (isDestFloat && !isSrcFloat) {
     return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>,  %arg2 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..79f5c133503d44 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
   llvm.return %rsrc : !llvm.ptr<8>
 }
 
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  llvm.return %r0 : vector<8 x f32>
+}
+
 llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32,

>From dbfc608b39b0e34fd7ca2b92e8c81fcb275488d5 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Thu, 29 Aug 2024 10:48:07 +0100
Subject: [PATCH 2/3] Address review feedback

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 45c5070333b527..8c739de0ab1516 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -399,8 +399,12 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
 
-  auto mlirInputType = dyn_cast<VectorType>(mlirInput.getType());
-  if (mlirInputType.getElementType().isInteger(8)) {
+  // We need to check the type of the input before conversion to properly test
+  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
+  // fp8/int8 information is lost during the conversion process.
+  auto mlirInputType = cast<VectorType>(mlirInput.getType());
+  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
+  if (isInputInt8) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
     if (elemType.isUnsignedInteger(8)) {
@@ -593,22 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   auto elemSourceType = sourceVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();
 
-  if (elemSourceType.isF16() && elemDestType.isF32()) {
+  if (elemSourceType.isF16() && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  }
-  if (elemSourceType.isBF16() && elemDestType.isF32()) {
+  if (elemSourceType.isBF16() && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+  if (elemSourceType.isF16() && elemDestType.isF16())
     return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+  if (elemSourceType.isBF16() && elemDestType.isBF16())
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) {
+  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) {
+  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
-  }
   return std::nullopt;
 }
 

>From b0ceff8150303b7d6e1f72e1b40b1dc290abbc98 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 2 Sep 2024 20:01:53 +0100
Subject: [PATCH 3/3] Address review feedback - 2

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index a8d6ccdc1a471e..1bc41ba9c8cf57 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -233,11 +233,10 @@ LogicalResult WMMAOp::verify() {
   Type sourceAElemType = sourceVectorAType.getElementType();
   Type destElemType = destVectorType.getElementType();
 
-  bool isDestFloat =
-      (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
+  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
   bool isSrcFloat =
-      (sourceAElemType.isF16() || sourceAElemType.isBF16() ||
-       sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2());
+      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
+          sourceAElemType);
 
   if (isDestFloat && !isSrcFloat) {
     return emitOpError("Expected float sources with float destination");



More information about the Mlir-commits mailing list