[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)
Pietro Ghiglio
llvmlistbot at llvm.org
Tue Jan 7 08:01:33 PST 2025
================
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+ Type type = op.getType(0);
+ return isa<BFloat16Type>(type) || type.isInteger(1);
+ }
+
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Type>(oldTy)
+ .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+ .Case<IntegerType>([&](auto intTy) -> Type {
+ if (intTy.getWidth() == 1)
+ return rewriter.getIntegerType(8);
+ return Type{};
+ })
+ .Default([](auto) { return Type{}; });
+ }
+
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (!hasValidWidth(op))
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName;
+ Value inValue;
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastorext");
+ funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+ inValue = newVal;
+ } else {
+ funcName = getFuncName(op);
+ inValue = adaptor.getValue();
+ }
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastortrunc");
+ result = newVal;
+ }
----------------
PietroGhg wrote:
Done, thank you
https://github.com/llvm/llvm-project/pull/119675
More information about the Mlir-commits mailing list