[Mlir-commits] [mlir] [mlir][AMDGPU] implement ScaledExtPackedFp8Op and PackedScaledTrunc2xFp8Op (PR #141554)
llvmlistbot at llvm.org
Mon May 26 23:55:05 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
Author: Tim Gymnich (tgymnich)
Changes:
Implement ScaledExtPackedFp8Op and PackedScaledTrunc2xFp8Op.
---
Patch is 20.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141554.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+56)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+130-1)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+6)
- (added) mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir (+108)
- (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 02308568c1ad1..301705bd1786b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -112,6 +112,33 @@ def AMDGPU_ExtPackedFp8Op :
}];
}
+def AMDGPU_ScaledExtPackedFp8Op :
+ AMDGPU_Op<"scaled_ext_packed_fp8", [Pure]>,
+ Arguments<(ins AnyTypeOf<[F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>]>:$source,
+ F32:$scale,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+ Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+ let summary = "Extend and scale an fp8 value to a float or a vector of packed fp8 values to two floats";
+
+ let description = [{
+ Extend and scale one or two 8-bit floats in `source[index]` to a 32-bit float
+ or two 32-bit floats and return them.
+
+ This rather unusual signature arises from the fact that AMD GPUs cannot
+ easily work with sub-32-bit quantities, so the compiler intrinsics for
+ extending 8-bit floats (which are currently the only way to perform these
+ conversions) take packed vectors of 4 such floats.
+
+ If the passed-in vector has fewer than four elements, or the input is scalar,
+ the remaining values in the <4 x i8> will be filled with undefined values as
+ needed.
+ }];
+ let assemblyFormat = [{
+ attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
+ }];
+}
+
def AMDGPU_PackedTrunc2xFp8Op :
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins F32:$sourceA,
@@ -139,6 +166,35 @@ def AMDGPU_PackedTrunc2xFp8Op :
let hasVerifier = 1;
}
+def AMDGPU_PackedScaledTrunc2xFp8Op :
+ AMDGPU_Op<"packed_scaled_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+ Arguments<(ins F32:$sourceA,
+ Optional<F32>:$sourceB,
+ F32:$scale,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+ Optional<FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
+ let summary = "Scale and round two floats into a packed vector of 8-bit floats";
+ let description = [{
+ Scale and round the inputs `sourceA` and `sourceB` (the latter of which is
+ undefined if not specified) into the low or high word (bottom two or top two
+ elements) of the returned vector, keeping the other two elements of
+ `existing` unchanged if present (or undefined if it was not passed in).
+
+ The reason for this odd signature is that AMD GPUs cannot easily work with
+ sub-registers, and so the conversion intrinsics (which are currently the
+ only way to work with 8-bit float types) take packed vectors of four 8-bit
+ values.
+ }];
+ let assemblyFormat = [{
+ attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+ `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+ `,` $scale
+ `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_PackedStochRoundFp8Op :
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
Arguments<(ins F32:$source,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c5094799bbef7..5fc8e370ac4c4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1148,6 +1148,19 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct ScaledExtPackedFp8OpLowering final
+ : public ConvertOpToLLVMPattern<ScaledExtPackedFp8Op> {
+ ScaledExtPackedFp8OpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedFp8Op>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1161,6 +1174,20 @@ struct PackedTrunc2xFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct PackedScaledTrunc2xFp8OpLowering final
+ : public ConvertOpToLLVMPattern<PackedScaledTrunc2xFp8Op> {
+ PackedScaledTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::PackedScaledTrunc2xFp8Op>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PackedScaledTrunc2xFp8Op op,
+ PackedScaledTrunc2xFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedStochRoundFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1229,6 +1256,67 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
return success();
}
+// rocdl.cvt.scalef32.pk.f32.fp8 %source[false]: i32, %c4: f32 : vector<2xf32>
+// rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+
+// amdgpu.scaled_ext_packed_fp8 %v[0]: f8E5M2, %scale: f32 : f8E5M2 to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<2xf8E5M2>, %scale: f32 : vector<2xf8E5M2> to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<4xf8E5M2>, %scale: f32 : vector<4xf8E5M2> to vector<2xf32>
+LogicalResult ScaledExtPackedFp8OpLowering::matchAndRewrite(
+ ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ if (chipset != kGfx950)
+ return rewriter.notifyMatchFailure(
+ loc, "Scaled fp8 conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ Type v4i8 =
+ getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
+ Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+ Type f32 = getTypeConverter()->convertType(op.getResult().getType());
+
+ Value source = adaptor.getSource();
+ Value scale = adaptor.getScale();
+ auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+ auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
+ Type sourceElemType = getElementTypeOrSelf(op.getSource());
+ // Extend to a v4i8
+ if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+ Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ if (!sourceVecType) {
+ longVec = rewriter.create<LLVM::InsertElementOp>(
+ loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ } else {
+ for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+ Value idx = createI32Constant(rewriter, loc, i);
+ Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ longVec =
+ rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ }
+ }
+ source = longVec;
+ }
+ Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ if (resultVecType) {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ }
+ } else {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Bf8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Fp8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ }
+ }
+ return success();
+}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
@@ -1266,6 +1354,46 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
return success();
}
+// rocdl.cvt.scalef32.pk.fp8.f32 %sourceA: f32, %sourceB: f32, %c0: f32 ->
+// %old[false]: vector<2xi16> : vector<2xi16>
+LogicalResult PackedScaledTrunc2xFp8OpLowering::matchAndRewrite(
+ PackedScaledTrunc2xFp8Op op, PackedScaledTrunc2xFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ if (chipset != kGfx950)
+ return rewriter.notifyMatchFailure(
+ loc, "Scaled fp8 conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ Type v2i16 = getTypeConverter()->convertType(
+ VectorType::get(2, rewriter.getI16Type()));
+
+ Type resultType = op.getResult().getType();
+ Type resultElemType = getElementTypeOrSelf(resultType);
+
+ Value sourceA = adaptor.getSourceA();
+ Value sourceB = adaptor.getSourceB();
+ Value scale = adaptor.getScale();
+ if (!sourceB)
+ sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ Value existing = adaptor.getExisting();
+ if (existing)
+ existing = rewriter.create<LLVM::BitcastOp>(loc, v2i16, existing);
+ else
+ existing = rewriter.create<LLVM::UndefOp>(loc, v2i16);
+
+ Value result;
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
+ result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
+ result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+
+ result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+ op, getTypeConverter()->convertType(resultType), result);
+ return success();
+}
+
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1675,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ ExtPackedFp8OpLowering, ScaledExtPackedFp8OpLowering,
+ PackedTrunc2xFp8OpLowering, PackedScaledTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index a0a98a4e86721..b24a185d21180 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -54,6 +54,12 @@ LogicalResult PackedTrunc2xFp8Op::verify() {
return success();
}
+LogicalResult PackedScaledTrunc2xFp8Op::verify() {
+ if (getExisting() && getExisting().getType() != getResult().getType())
+ return emitOpError("existing values must have same type as result");
+ return success();
+}
+
LogicalResult PackedStochRoundFp8Op::verify() {
if (getExisting() && getExisting().getType() != getResult().getType())
return emitOpError("existing values must have same type as result");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
new file mode 100644
index 0000000000000..128b8eabd76cd
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+
+// CHECK-LABEL: func @scaled_ext_scalar
+// CHECK-SAME: ([[IN:%.+]]: f8E5M2, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.bf8 [[CAST]][0], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_scalar(%v: f8E5M2, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale: f8E5M2 to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_short_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][1], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_short_vec(%v: vector<2xf8E4M3FN>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_full_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][3], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_full_vec(%v: vector<4xf8E4M3FN>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[3], %scale : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_2xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][false], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @scaled_ext_packed_2xfp8(%v: vector<2xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<2xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_4xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][true], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @scaled_ext_packed_4xfp8(%v: vector<4xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @packed_scaled_trunc
+// CHECK-SAME: ([[V:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[V2]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_trunc(%v: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, undef into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_truncx2(%v: f32, %w: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.bf8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING_INT]][true] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_scaled_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into %existing[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+ func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 188cfcc4eb38b..d1d56bd3b5178 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -18,6 +18,20 @@ func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
func.return %ret : vector<2xf32>
}
+// CHECK-LABEL: func @scaled_ext_packed_fp8_s
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to f32
+func.func @scaled_ext_packed_fp8_s(%v: vector<4xf8E5M2>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_fp8_v
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to vector<2xf32>
+func.func @scaled_ext_packed_fp8_v(%v: vector<4xf8E5M2>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
// CHECK-LABEL: func @packed_trunc_2xfp8
// CHECK: amdgpu.packed_trunc_2xfp8
func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
@@ -25,6 +39,13 @@ func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>,
func.return %ret : vector<4xf8E5M2FNUZ>
}
+// CHECK-LABEL: func @scaled_packed_trunc_2xfp8
+// ...
[truncated]
``````````
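For quick reference, here is a minimal standalone sketch of the new extension op (not part of the diff; the function name is illustrative), modeled on the assembly format and tests above. It assumes a gfx950 target, where `-convert-amdgpu-to-rocdl=chipset=gfx950` lowers it to `rocdl.cvt.scalef32.pk.f32.fp8`:

```mlir
// Hypothetical usage sketch, mirroring the new tests: extend the two fp8
// lanes selected by the word index and apply the f32 scale.
func.func @example_scaled_ext(%v: vector<4xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
  // Word index 1 selects elements 2 and 3 of %v.
  %hi = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf32>
  func.return %hi : vector<2xf32>
}
```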
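And a similar sketch of the new truncation op (again illustrative, not taken from the diff), which scales and rounds two f32 values into one word of an existing packed vector; for f8E5M2 results on gfx950 it lowers to `rocdl.cvt.scalef32.pk.bf8.f32`:

```mlir
// Hypothetical usage sketch: pack %a and %b into word 1 (the upper two bytes)
// of %acc, scaled by %scale, leaving word 0 of %acc unchanged.
func.func @example_scaled_trunc(%a: f32, %b: f32, %acc: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
  %ret = amdgpu.packed_scaled_trunc_2xfp8 %a, %b into %acc[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
  func.return %ret : vector<4xf8E5M2>
}
```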
https://github.com/llvm/llvm-project/pull/141554