[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. (PR #111400)
Benoit Jacob
llvmlistbot at llvm.org
Mon Oct 7 10:44:32 PDT 2024
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/111400
>From f0909e4992a9a979b06abfaea63fecbfc134b90b Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 7 Oct 2024 11:11:13 -0500
Subject: [PATCH 1/3] bitcast
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 40 ++++++-------------
1 file changed, 12 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16())
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
-
- if (!vectorType.getElementType().isInteger(8))
- return input;
- int64_t numBytes = vectorType.getNumElements();
- Type destType = rewriter.getIntegerType(numBytes * 8);
- Value result = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, 0));
- for (int64_t i = 0; i < numBytes; ++i) {
- Value idxConst = createI32Constant(rewriter, loc, i);
- Value element =
- rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
- Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
- Value shiftConst = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, i * 8));
- Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
- result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+ if (vectorType.getElementType().isInteger(8)) {
+ return rewriter.create<LLVM::BitcastOp>(
+ loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
}
- return result;
}
return input;
}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
- mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid()),
createI32Constant(rewriter, loc, getBlgpField)});
>From 5febb7ea1eec75482996796dfbe6974d956bd359 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 7 Oct 2024 12:38:03 -0500
Subject: [PATCH 2/3] add test
---
mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
index 7ef9d172d52cd1..2fbc67a9804337 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 -cse | FileCheck %s
func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
%arg2 : vector<16xf32>, %arg3 : vector<4xf32>,
%arg4 : vector<4xf16>, %arg5 : vector<4xi8>,
@@ -28,7 +28,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
- // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
+ // CHECK: %[[BITCAST_4xi8_i32:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
+ // CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
// CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
@@ -38,7 +39,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
// CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
- // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: %[[BITCAST_2xbf16_2xi16:.+]] = llvm.bitcast {{.*}} : vector<2xbf16> to vector<2xi16>
+ // CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
@@ -48,7 +50,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: %[[BITCAST_4xbf16_4xi16:.+]] = llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+ // CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
@@ -62,7 +65,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>
// CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
- // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: %[[BITCAST_8xi8_i64:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+ // CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
// CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
>From 1a54d624ce786bc6c0df0f255b869d078524ee79 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 7 Oct 2024 12:44:14 -0500
Subject: [PATCH 3/3] Also check for bitcast on bf8/fp8 testcases
---
mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
index 2fbc67a9804337..f8a60d37801ebe 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -74,9 +74,11 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: %[[BITCAST_8xi8_i64_1:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: %[[BITCAST_8xi8_i64_2:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
// CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
More information about the Mlir-commits
mailing list