[Mlir-commits] [mlir] [mlir][AMDGPU] Remove an old bf16 workaround (PR #108409)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Sep 12 08:40:01 PDT 2024
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/108409
The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL.
As a result of this change, we discovered that, while the code for the MFMA and WMMA intrinsics was mostly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations back from the i16 values they're natively represented as. This commit also fixes that issue.
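For illustration, here is a rough sketch of what the fixed lowering now produces for a bf16 WMMA on gfx11/gfx12 (the value names are made up and the input-operand handling, which already bitcast bf16 inputs to i16, is elided); it mirrors the expectations added to wmma.mlir below:

    // amdgpu.wmma %a * %a + %c {subwordOffset = 1 : i32}
    //     : vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
    // lowers to (roughly):
    %raw = rocdl.wmma.bf16.16x16x16.bf16 %a_i16, %a_i16, %c_i16, %opsel
        : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
    %res = llvm.bitcast %raw : vector<16xi16> to vector<16xbf16>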
From 5c5d3d92a8dd42f8f20fd7ea9e75ac48febadc55 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 11 Sep 2024 18:23:39 +0000
Subject: [PATCH] [mlir][AMDGPU] Remove an old bf16 workaround
The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we
no longer need to type convert MLIR's `bf16` to `i16` during lowerings
to ROCDL.
As a result of this change, we discovered that, while the code for
the MFMA and WMMA intrinsics was mostly prepared for this change, we
were failing to bitcast the bf16 results of WMMA operations back from
the i16 values they're natively represented as. This commit also
fixes that issue.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 27 ++++++++++---------
mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir | 1 +
.../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 14 +++++-----
3 files changed, 23 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c2785f34564e3b..31d35390a7e7f8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,25 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Type outType = typeConverter->convertType(op.getDestD().getType());
+ auto outType =
+ cast<VectorType>(typeConverter->convertType(op.getDestD().getType()));
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
+ // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+ // need to bitcast bfloats to i16 and then bitcast them back.
+ VectorType rawOutType = outType;
+ if (outType.getElementType().isBF16())
+ rawOutType = outType.clone(rewriter.getI16Type());
+
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value())
return op.emitOpError("no intrinsic matching WMMA on the given chipset");
OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(outType);
+ loweredOp.addTypes(rawOutType);
SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +701,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
loweredOp.addOperands(operands);
Operation *lowered = rewriter.create(loweredOp);
- rewriter.replaceOp(op, lowered->getResults());
+
+ Operation *maybeCastBack = lowered;
+ if (rawOutType != outType)
+ maybeCastBack =
+ rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
}
@@ -1033,15 +1045,6 @@ struct ConvertAMDGPUToROCDLPass
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
- converter.addConversion([](BFloat16Type t) -> Type {
- return IntegerType::get(t.getContext(), 16);
- });
- converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
- if (!t.getElementType().isBF16())
- return std::nullopt;
- return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
- });
-
patterns
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 1a4ef33db2aed5..9ca89a0babd951 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -16,6 +16,7 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 56b65beb036954..3fa9fa5e935d2e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -445,22 +445,22 @@ gpu.module @test_module {
// -----
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
gpu.module @test_module {
// CHECK-LABEL: func @bf16_id
func.func @bf16_id(%arg0 : bf16) -> bf16 {
- // CHECK-SAME: (%[[ARG0:.+]]: i16)
- // CHECK-SAME: -> i16
- // CHECK: return %[[ARG0]] : i16
+ // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+ // CHECK-SAME: -> bf16
+ // CHECK: return %[[ARG0]] : bf16
func.return %arg0 : bf16
}
// CHECK-LABEL: func @bf16x4_id
func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
- // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
- // CHECK-SAME: -> vector<4xi16>
- // CHECK: return %[[ARG0]] : vector<4xi16>
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+ // CHECK-SAME: -> vector<4xbf16>
+ // CHECK: return %[[ARG0]] : vector<4xbf16>
func.return %arg0 : vector<4xbf16>
}