[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Nov 17 07:03:41 PST 2025
================
@@ -1613,6 +1613,182 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+/// Merges the blockSize, source bit width, firstScaleLane and firstScaleByte
+/// attributes of amdgpu.scaled_ext_packed816 into the single `scaleSel`
+/// attribute expected by the rocdl.cvt.scale.pk*.f*.f* operations.
+///
+/// Bit layout (bit 0 is the least significant bit):
+///  - 4- and 6-bit sources (3-bit field):
+///      bit 0    = (blockSize == 16)
+///      bit 1    = (firstScaleByte == 2); firstScaleByte must be 0 or 2
+///      bit 2    = firstScaleLane
+///  - 8-bit sources (4-bit field):
+///      bit 0    = (blockSize == 16)
+///      bits 2:1 = firstScaleByte (0..3)
+///      bit 3    = firstScaleLane
+///    Not every encodable 8-bit value is accepted: lane 1 cannot start at
+///    byte 0, and blockSize 16 is only accepted for (lane 0, byte 0) and
+///    (lane 1, byte 2).
+///
+/// firstScaleLane must be 0 or 1. Any unsupported combination is treated as
+/// a programming error and reaches llvm_unreachable.
+static int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
+                       int firstScaleByte) {
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained({4, 6, 8}, bitWidth));
+
+  const bool isFp8 = bitWidth == 8;
+  const int block16Bit = blockSize == 16 ? 1 : 0;
+  const bool isValidLane = firstScaleLane == 0 || firstScaleLane == 1;
+
+  if (!isFp8) {
+    // Only bytes 0 and 2 are selectable; every such lane/byte/block
+    // combination is supported.
+    if (isValidLane && (firstScaleByte == 0 || firstScaleByte == 2))
+      return (firstScaleLane << 2) | ((firstScaleByte == 2 ? 1 : 0) << 1) |
+             block16Bit;
+  } else if (isValidLane && firstScaleByte >= 0 && firstScaleByte <= 3) {
+    const int scaleSel =
+        (firstScaleLane << 3) | (firstScaleByte << 1) | block16Bit;
+    // Encodable but unsupported 8-bit selections (see the function comment).
+    if (!llvm::is_contained(
+            {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111},
+            scaleSel))
+      return scaleSel;
+  }
+
+  llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
+                   "blockSize and type.");
+}
+
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+ ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+
+ int firstScaleLane = op.getFirstScaleLane();
+ int firstScaleByte = op.getFirstScaleByte();
+ int blockSize = op.getBlockSize();
+ auto sourceType = cast<VectorType>(op.getSource().getType());
+ auto srcElemType = cast<FloatType>(sourceType.getElementType());
+ int bitWidth = srcElemType.getWidth();
+ int scaleSel =
+ getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+ auto targetType = cast<VectorType>(op.getResult().getType());
+ auto destElemType = cast<FloatType>(targetType.getElementType());
+ Location loc = op.getLoc();
+ IntegerType i32 = rewriter.getI32Type();
+ Value castedScale =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
+ Value source = adaptor.getSource();
+ Type packedType;
+ if (isa<Float4E2M1FNType>(srcElemType)) {
+ packedType = i32;
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<Float8E4M3FNType>(srcElemType) ||
+ isa<Float8E5M2Type>(srcElemType)) {
+ packedType = VectorType::get(2, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<Float6E2M3FNType>(srcElemType) ||
+ isa<Float6E3M2FNType>(srcElemType)) {
+ packedType = VectorType::get(3, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else {
+ llvm_unreachable("invalid element type for scaled ext");
+ }
+ // smallT = [Fp4, Fp8, Bf8]
+ // Bf8 = E5M2
+ // Fp8 = E4M3
+ //
+ // largeT = [F16, Bf16, F32]
+ // CvtPkScalePk8${largeT}${smallT}
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
+
+ if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
----------------
amd-eochoalo wrote:
https://github.com/llvm/llvm-project/pull/168123/commits/33ef57e0dce2640dc8c3cf3c5623ffc71eb42d18
https://github.com/llvm/llvm-project/pull/168123
More information about the Mlir-commits
mailing list