[Mlir-commits] [mlir] 6292ea6 - [mlir][AMDGPU] Remove an old bf16 workaround (#108409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 12 15:45:43 PDT 2024


Author: Krzysztof Drewniak
Date: 2024-09-12T17:45:39-05:00
New Revision: 6292ea6879217468cd9187d4f4dd3ee7c713431c

URL: https://github.com/llvm/llvm-project/commit/6292ea6879217468cd9187d4f4dd3ee7c713431c
DIFF: https://github.com/llvm/llvm-project/commit/6292ea6879217468cd9187d4f4dd3ee7c713431c.diff

LOG: [mlir][AMDGPU] Remove an old bf16 workaround (#108409)

The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no
longer need to type convert MLIR's `bf16` to `i16` during lowerings to
ROCDL.
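
As a rough illustration (mirroring the updated gpu-to-rocdl.mlir test below), a bf16
value now keeps its type through the ROCDL lowering instead of being rewritten to i16:

    // Previously the converter rewrote bf16 (and vectors of bf16) to i16, so this
    // function lowered to (%arg0: i16) -> i16. With this change the type passes through:
    func.func @bf16_id(%arg0 : bf16) -> bf16 {
      func.return %arg0 : bf16  // remains bf16 after conversion
    }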

As a result of this change, we discovered that, while the code for MFMA
and WMMA intrinsics was largely prepared for this change, we were failing
to bitcast the bf16 results of WMMA operations back from the i16 values
they are natively represented as. This commit also fixes that issue.
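
A sketch of what the fixed lowering now produces for the bf16 case (abridged from the
wmma.mlir test below; value names are illustrative and the intrinsic operands are elided):

    %d = amdgpu.wmma %a * %a + %c : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
    // lowers to the i16-typed intrinsic followed by a bitcast back to bf16:
    %raw  = rocdl.wmma.bf16.16x16x16.bf16 ... : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
    %d_bf = llvm.bitcast %raw : vector<8xi16> to vector<8xbf16>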

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c2785f34564e3b..f80d2793eaef59 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,27 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type outType = typeConverter->convertType(op.getDestD().getType());
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
+    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+    // need to bitcast bfloats to i16 and then bitcast them back.
+    VectorType rawOutType = outType;
+    if (outType.getElementType().isBF16())
+      rawOutType = outType.clone(rewriter.getI16Type());
+
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
     OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(outType);
+    loweredOp.addTypes(rawOutType);
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +703,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     loweredOp.addOperands(operands);
     Operation *lowered = rewriter.create(loweredOp);
-    rewriter.replaceOp(op, lowered->getResults());
+
+    Operation *maybeCastBack = lowered;
+    if (rawOutType != outType)
+      maybeCastBack =
+          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+    rewriter.replaceOp(op, maybeCastBack->getResults());
 
     return success();
   }
@@ -1033,15 +1047,6 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  converter.addConversion([](BFloat16Type t) -> Type {
-    return IntegerType::get(t.getContext(), 16);
-  });
-  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
-    if (!t.getElementType().isBF16())
-      return std::nullopt;
-    return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
-  });
-
   patterns
       .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
            RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 1a4ef33db2aed5..7b144809235d50 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -15,9 +15,11 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
   amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
   amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>

diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 2396ddf6b14b83..eb065cbab86789 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -507,22 +507,22 @@ gpu.module @test_module {
 
 // -----
 
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
 
 gpu.module @test_module {
   // CHECK-LABEL: func @bf16_id
   func.func @bf16_id(%arg0 : bf16) -> bf16 {
-    // CHECK-SAME: (%[[ARG0:.+]]: i16)
-    // CHECK-SAME: -> i16
-    // CHECK: return %[[ARG0]] : i16
+    // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+    // CHECK-SAME: -> bf16
+    // CHECK: return %[[ARG0]] : bf16
     func.return %arg0 : bf16
   }
 
   // CHECK-LABEL: func @bf16x4_id
   func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
-    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
-    // CHECK-SAME: -> vector<4xi16>
-    // CHECK: return %[[ARG0]] : vector<4xi16>
+    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+    // CHECK-SAME: -> vector<4xbf16>
+    // CHECK: return %[[ARG0]] : vector<4xbf16>
     func.return %arg0 : vector<4xbf16>
   }
 

