[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled floating point conversion ops (PR #141554)
Tim Gymnich
llvmlistbot at llvm.org
Wed Jun 11 06:21:27 PDT 2025
================
@@ -1230,6 +1257,157 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
+ ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ if (chipset != kGfx950)
+ return rewriter.notifyMatchFailure(
+ loc, "Scaled fp conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+ Value source = adaptor.getSource();
+ Value scale = adaptor.getScale();
+
+ VectorType sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+ Type sourceElemType = getElementTypeOrSelf(op.getSource());
+ VectorType destVecType = dyn_cast<VectorType>(op.getResult().getType());
+ Type destElemType = getElementTypeOrSelf(op.getResult());
+
+ VectorType packedVecType;
+ if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
+ VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
+ packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
+ } else if (isa<Float4E2M1FNType>(sourceElemType)) {
+ VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
+ packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
+ } else {
+ llvm_unreachable("invalid element type for scaled ext");
+ }
+
+ // Extend to a packedVectorType
+ if (!sourceVecType ||
+ sourceVecType.getNumElements() < packedVecType.getNumElements()) {
+ Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ if (!sourceVecType) {
+ longVec = rewriter.create<LLVM::InsertElementOp>(
+ loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ } else {
+ for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+ Value idx = createI32Constant(rewriter, loc, i);
+ Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ longVec =
+ rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ }
+ }
+ source = longVec;
+ }
+ Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+
+ if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
+ op, destVecType, i32Source, scale, op.getIndex());
+ else
+ return failure();
----------------
tgymnich wrote:
I feel like this makes the code less readable.
https://github.com/llvm/llvm-project/pull/141554
More information about the Mlir-commits mailing list