[Mlir-commits] [mlir] [mlir][AMDGPU] Remove an old bf16 workaround (PR #108409)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Sep 12 12:08:20 PDT 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/108409

>From 5c5d3d92a8dd42f8f20fd7ea9e75ac48febadc55 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 11 Sep 2024 18:23:39 +0000
Subject: [PATCH 1/4] [mlir][AMDGPU] Remove an old bf16 workaround

The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we
no longer need to type convert MLIR's `bf16` to `i16` during lowerings
to ROCDL.

As a result of this change, we discovered that, while the code for MFMA
and WMMA intrinsics was mainly prepared for this change, we were
failing to bitcast the bf16 results of WMMA operations out from the
i16 they're natively represented as. This commit also fixes that issue.
---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 27 ++++++++++---------
 mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir  |  1 +
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 14 +++++-----
 3 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c2785f34564e3b..31d35390a7e7f8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,25 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type outType = typeConverter->convertType(op.getDestD().getType());
+    auto outType =
+        cast<VectorType>(typeConverter->convertType(op.getDestD().getType()));
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
+    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+    // need to bitcast bfloats to i16 and then bitcast them back.
+    VectorType rawOutType = outType;
+    if (outType.getElementType().isBF16())
+      rawOutType = outType.clone(rewriter.getI16Type());
+
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
     OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(outType);
+    loweredOp.addTypes(rawOutType);
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +701,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     loweredOp.addOperands(operands);
     Operation *lowered = rewriter.create(loweredOp);
-    rewriter.replaceOp(op, lowered->getResults());
+
+    Operation *maybeCastBack = lowered;
+    if (rawOutType != outType)
+      maybeCastBack =
+          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+    rewriter.replaceOp(op, maybeCastBack->getResults());
 
     return success();
   }
@@ -1033,15 +1045,6 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  converter.addConversion([](BFloat16Type t) -> Type {
-    return IntegerType::get(t.getContext(), 16);
-  });
-  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
-    if (!t.getElementType().isBF16())
-      return std::nullopt;
-    return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
-  });
-
   patterns
       .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
            RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 1a4ef33db2aed5..9ca89a0babd951 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -16,6 +16,7 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
   amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 56b65beb036954..3fa9fa5e935d2e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -445,22 +445,22 @@ gpu.module @test_module {
 
 // -----
 
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
 
 gpu.module @test_module {
   // CHECK-LABEL: func @bf16_id
   func.func @bf16_id(%arg0 : bf16) -> bf16 {
-    // CHECK-SAME: (%[[ARG0:.+]]: i16)
-    // CHECK-SAME: -> i16
-    // CHECK: return %[[ARG0]] : i16
+    // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+    // CHECK-SAME: -> bf16
+    // CHECK: return %[[ARG0]] : bf16
     func.return %arg0 : bf16
   }
 
   // CHECK-LABEL: func @bf16x4_id
   func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
-    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
-    // CHECK-SAME: -> vector<4xi16>
-    // CHECK: return %[[ARG0]] : vector<4xi16>
+    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+    // CHECK-SAME: -> vector<4xbf16>
+    // CHECK: return %[[ARG0]] : vector<4xbf16>
     func.return %arg0 : vector<4xbf16>
   }
 

>From 5f842ca428bc401fda31228f32e828a45d1d5b37 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 12 Sep 2024 18:12:45 +0000
Subject: [PATCH 2/4] Review comments

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 5 ++++-
 mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir        | 7 ++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 31d35390a7e7f8..42d22d9dfae24e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -672,7 +672,10 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto outType =
-        cast<VectorType>(typeConverter->convertType(op.getDestD().getType()));
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(
+          op, "wmma output doesn't convert to a vector for no clear reason");
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 9ca89a0babd951..7b144809235d50 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -15,10 +15,11 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
   amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
+  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
   amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>

>From d7f7973804d5fd38bcb0b8253ee949f707e71de0 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 12 Sep 2024 14:02:26 -0500
Subject: [PATCH 3/4] Shorten error message on type conversion failure

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 42d22d9dfae24e..28991b6b9ed2bb 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -675,7 +675,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
         typeConverter->convertType<VectorType>(op.getDestD().getType());
     if (!outType)
       return rewriter.notifyMatchFailure(
-          op, "wmma output doesn't convert to a vector for no clear reason");
+          op, "type conversion failed");
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");

>From 2dfd30bdef32976282a93dd8e4b1c7fc53d9381f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 12 Sep 2024 19:08:00 +0000
Subject: [PATCH 4/4] Fix clang-format from suggestion

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 28991b6b9ed2bb..f80d2793eaef59 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -674,8 +674,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     auto outType =
         typeConverter->convertType<VectorType>(op.getDestD().getType());
     if (!outType)
-      return rewriter.notifyMatchFailure(
-          op, "type conversion failed");
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");



More information about the Mlir-commits mailing list