[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)
Pietro Ghiglio
llvmlistbot at llvm.org
Thu Dec 12 00:49:17 PST 2024
https://github.com/PietroGhg created https://github.com/llvm/llvm-project/pull/119675
This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect by inserting bitcasts to/from `i16` (for `bf16`) and zero-extending to/truncating from `i8` (for `i1`).
>From 9e3af90821eeec58d715a0559ceaa61bee83b999 Mon Sep 17 00:00:00 2001
From: PietroGhg <pietro.ghiglio at codeplay.com>
Date: Mon, 9 Dec 2024 11:24:28 +0000
Subject: [PATCH] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV
conversion
---
.../Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp | 83 +++++++++++++++++--
.../GPUToLLVMSPV/gpu-to-llvm-spv.mlir | 19 ++++-
2 files changed, 91 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ /// Returns the mangled builtin name `_Z<len><baseName><typeMangling>` used
+ /// for a shuffle in \p mode on values of \p type, or std::nullopt when
+ /// getTypeMangling() has no mangling for \p type.
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
typeMangling.value());
}
+ /// Convenience overload taking the shuffle mode and the value type (result
+ /// 0) from \p op.
+ static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+ return getFuncName(op.getMode(), op.getType(0));
+ }
+
/// Get the subgroup size from the target or return a default.
static std::optional<int> getSubgroupSize(Operation *op) {
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ /// Returns the type a shuffled value of type \p oldTy is converted to before
+ /// calling the builtin, or a null type when \p oldTy is passed through
+ /// unchanged: `bf16` is shuffled as `i16` and `i1` is shuffled as `i8`.
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ if (isa<BFloat16Type>(oldTy))
+ return rewriter.getIntegerType(16);
+ if (oldTy.isInteger(1))
+ return rewriter.getIntegerType(8);
+ return Type{};
+ }
+
+ /// Converts \p oldVal to \p newTy ahead of the builtin call: `bf16` values
+ /// are bitcast and `i1` values are zero-extended. Returns a null Value for
+ /// any other source type.
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ Type oldTy = oldVal.getType();
+ if (isa<BFloat16Type>(oldTy))
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ if (oldTy.isInteger(1))
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ }
+
+ /// Inverse of doBitcastOrExt: converts the builtin result \p oldVal back to
+ /// the original type \p newTy by bitcasting to `bf16` or truncating to `i1`.
+ /// Returns a null Value for any other destination type.
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ if (isa<BFloat16Type>(newTy))
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ if (newTy.isInteger(1))
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ // `bf16` and `i1` have no builtin overload of their own: they are shuffled
+ // as `i16` and `i8` respectively and converted back after the call.
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName =
+ bitcastOrExtDestTy ? getFuncName(op.getMode(), bitcastOrExtDestTy)
+ : getFuncName(op);
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
+ // Materialize the operand conversion only once the pattern is known to
+ // succeed, so that no IR is created on the match-failure path above.
+ Value inValue = adaptor.getValue();
+ if (bitcastOrExtDestTy) {
+ inValue = doBitcastOrExt(inValue, bitcastOrExtDestTy, loc, rewriter);
+ assert(inValue && "Unhandled op type in bitcastorext");
+ }
+
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ // Convert the builtin result back to the original shuffled type.
+ if (bitcastOrExtDestTy) {
+ result =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(result && "Unhandled op type in bitcastortrunc");
+ }
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
+ // CHECK: llvm.trunc %[[I1_CALL]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -342,10 +355,10 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%val: f128, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
// expected-error at below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
+ %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : f128
llvm.return
}
}
More information about the Mlir-commits
mailing list