[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)
Pietro Ghiglio
llvmlistbot at llvm.org
Wed Jan 8 05:37:30 PST 2025
https://github.com/PietroGhg updated https://github.com/llvm/llvm-project/pull/119675
From 60ef4fd1b621cbebefac2c6916d937e30138cd5b Mon Sep 17 00:00:00 2001
From: PietroGhg <pietro.ghiglio at codeplay.com>
Date: Mon, 9 Dec 2024 11:24:28 +0000
Subject: [PATCH] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV
conversion
---
.../Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp | 56 ++++++++++++++++---
.../GPUToLLVMSPV/gpu-to-llvm-spv.mlir | 19 ++++++-
2 files changed, 63 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index a68c0153df4432..8b6b553f6eed05 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,12 +262,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
- return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
+ return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
typeMangling.value());
}
@@ -286,6 +287,37 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case([&](BFloat16Type) {
+ return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
+ oldVal);
+ })
+ .Case([&](IntegerType intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
+ oldVal);
+ return oldVal;
+ })
+ .Default(oldVal);
+ }
+
+ static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
+ Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case([&](BFloat16Type) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case([&](IntegerType intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return oldVal;
+ })
+ .Default(oldVal);
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,26 +325,32 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Value inValue =
+ bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
+ std::optional<std::string> funcName =
+ getFuncName(op.getMode(), inValue.getType());
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ Value resultOrConversion =
+ bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
- rewriter.replaceOp(op, {result, trueVal});
+ rewriter.replaceOp(op, {resultOrConversion, trueVal});
return success();
}
};
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index e75225d6d54f55..c2930971dbcf9b 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -279,7 +279,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -287,6 +288,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -303,6 +306,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
+ // CHECK: llvm.trunc %[[I1_CALL]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -310,6 +321,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -344,10 +357,10 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%val: f128, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
+ %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : f128
llvm.return
}
}
More information about the Mlir-commits
mailing list