[Mlir-commits] [mlir] a8e1c6f - [MLIR][AMDGPU] Add support for fp8 ops on gfx12 (#106388)
llvmlistbot at llvm.org
Tue Sep 3 09:47:11 PDT 2024
Author: Giuseppe Rossini
Date: 2024-09-03T17:47:08+01:00
New Revision: a8e1c6f99abc273677afed5eaaeee2c0296db59f
URL: https://github.com/llvm/llvm-project/commit/a8e1c6f99abc273677afed5eaaeee2c0296db59f
DIFF: https://github.com/llvm/llvm-project/commit/a8e1c6f99abc273677afed5eaaeee2c0296db59f.diff
LOG: [MLIR][AMDGPU] Add support for fp8 ops on gfx12 (#106388)
This PR adds support for the `fp8` (f8E4M3FN) and `bf8` (f8E5M2) WMMA variants on gfx12.
Added:
mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Target/LLVMIR/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e5c1a53f34bf64..8a1ef94c853a58 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -552,7 +552,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index e832dfa9d6b80e..35fd8270ca6935 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -329,13 +329,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
"$args attr-dict `:` functional-type($args, $res)";
}
-// Available on RDNA3
+// Available from gfx11
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+// Available from gfx12
+def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 809e9448e80abf..7e407f1ca528d8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
+ Value mlirInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
+ // We need to check the type of the input before conversion to properly test
+ // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
+ // fp8/int8 information is lost during the conversion process.
+ auto mlirInputType = cast<VectorType>(mlirInput.getType());
+ bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
+ if (isInputInt8) {
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+ bool localIsUnsigned = isUnsigned;
+ if (elemType.isUnsignedInteger(8)) {
+ localIsUnsigned = true;
+ } else if (elemType.isSignedInteger(8)) {
+ localIsUnsigned = false;
+ }
+ Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+ operands.push_back(sign);
+ }
+
int64_t numBytes = vectorType.getNumElements();
Type i32 = rewriter.getI32Type();
VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
loc, llvmVectorType32bits, llvmInput);
-
- // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
- bool localIsUnsigned = isUnsigned;
- if (elemType.isUnsignedInteger(8)) {
- localIsUnsigned = true;
- } else if (elemType.isSignedInteger(8)) {
- localIsUnsigned = false;
- }
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
operands.push_back(result);
}
@@ -590,18 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
auto elemSourceType = sourceVectorType.getElementType();
auto elemDestType = destVectorType.getElementType();
- if (elemSourceType.isF16() && elemDestType.isF32()) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
- }
- if (elemSourceType.isBF16() && elemDestType.isF32()) {
+ if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
- } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+ if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
- } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
- } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- }
+ if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
+ if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
return std::nullopt;
}
@@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Location loc = op.getLoc();
Type outType = typeConverter->convertType(op.getDestD().getType());
- if (chipset.majorVersion != 11)
- return op->emitOpError("WMMA only supported on gfx11");
+ if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
+ return op->emitOpError("WMMA only supported on gfx11 and gfx12");
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), operands);
+ adaptor.getSourceA(), op.getSourceA(), operands);
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), operands);
+ adaptor.getSourceB(), op.getSourceB(), operands);
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
op.getSubwordOffset(), op.getClamp(), operands);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 3943696364950f..63447baa31eb0c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -234,9 +234,10 @@ LogicalResult WMMAOp::verify() {
Type sourceAElemType = sourceVectorAType.getElementType();
Type destElemType = destVectorType.getElementType();
- bool isDestFloat =
- (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
- bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+ bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
+ bool isSrcFloat =
+ isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
+ sourceAElemType);
if (isDestFloat && !isSrcFloat) {
return emitOpError("Expected float sources with float destination");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
new file mode 100644
index 00000000000000..7b2b524d4af426
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
+ // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+ func.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index d902a82eeb9ea2..97b505746fc751 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -377,6 +377,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
llvm.return %rsrc : !llvm.ptr<8>
}
+llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+ %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+ llvm.return %r0 : vector<8 x f32>
+}
+
llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
%offset : i32, %soffset : i32,
%vdata1 : i32,
More information about the Mlir-commits mailing list