[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Nov 17 11:45:43 PST 2025
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/168123
>From 7b64631ffa0f4e2880d0a839443e450120da7a68 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 13:23:20 -0500
Subject: [PATCH 01/43] Update documentation
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 35 +++++++++++++------
1 file changed, 25 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 45cb67f0eee4a..4820b7a747ac2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -127,7 +127,7 @@ def AMDGPU_ScaledExtPacked816Op
FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
- ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
+ ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$firstScaleByte)>,
Results<(
outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
FixedVectorOfShapeAndType<[8], F16>,
@@ -139,17 +139,21 @@ def AMDGPU_ScaledExtPacked816Op
let summary = "Extend a vector of packed floating point values";
let description = [{
- The scales applied to the input microfloats are stored in two bytes which
+ The scales applied to the input microfloats are stored in bytes which
come from the `scales` input provided in a *half* of the wave identified
- by `firstScaleLane`. The pair of bytes used is selected by
- `firstScaleByte`. The 16 vectors in consecutive lanes starting from
+ by `firstScaleLane`. The bytes used are selected by `firstScaleByte` and depend
+ on the type of `source`. The 16 vectors in consecutive lanes starting from
`firstScaleLane` (which we'll call the scale vectors) will be used by both
- halves of the wave (with lane L reading from L % 16'th scale vector), but
- each half will use a different byte.
+ halves of the wave (with lane L reading from L % 16'th scale vector).
+
+ When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN, each half of the
+ wave will use a different byte: the first being `firstScaleByte` and
+ the second being `firstScaleByte` + 1. When the block size is 32,
+ `firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors.
+ Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read
+ from `firstScaleByte` + 1.
+
- When the block size is 32, `firstScaleByte` can be either 0 or 2,
- selecting halves of the scale vectors. Lanes 0-15 will read from
- `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
For example:
```mlir
// Input: 8-element vector of F8E4M3FN, converting to F32
@@ -165,7 +169,8 @@ def AMDGPU_ScaledExtPacked816Op
: vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
```
- However, when the block size is 16, `firstScaleByte` can be 0 or 1.
+ When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and
+ the block size is 16, `firstScaleByte` can be 0 or 1.
Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
while lanes 16-31 read from `firstScaleByte` + 2.
For example:
@@ -187,6 +192,16 @@ def AMDGPU_ScaledExtPacked816Op
instructions use for matix scales. These selection operands allows
one to choose portions of the matrix to convert.
+ When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32,
+ then the same byte will be used by both halves of the wave.
+ In this case, `firstScaleByte` can be any value from 0 to 3.
+
+ When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16,
+ the following combinations are allowed:
+ * `firstScaleLane(0), firstScaleByte(0)`
+ * `firstScaleLane(1), firstScaleByte(2)`
+ All other combinations are reserved.
+
Available on gfx1250+.
}];
>From 08e96b19369451dd5ec4e72ed2905bd0b2e0cf71 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 13:44:08 -0500
Subject: [PATCH 02/43] Fix verifiers
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 23 +++++++++++++++-----
mlir/test/Dialect/AMDGPU/invalid.mlir | 20 ++++++++++++-----
2 files changed, 32 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index df955fc90b45f..5c35823678576 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -344,14 +344,27 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
LogicalResult ScaledExtPacked816Op::verify() {
int blockSize = getBlockSize();
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+
int firstScaleByte = getFirstScaleByte();
- if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 16 can only have firstScaleByte be 0 or 1.");
+ auto sourceType = cast<VectorType>(getSource().getType());
+ Type elementType = sourceType.getElementType();
+ auto floatType = cast<FloatType>(elementType);
+ int bitWidth = floatType.getWidth();
+
+ if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
+ !llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
+ "for f4 and f6.");
+ }
+ if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
+ !llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
+ "for f4 and f6.");
}
- if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
+ if (bitWidth == 8 && blockSize == 16 &&
+ !llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError(
- "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+ "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
}
return success();
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 4c6f62a045405..5c8cc8b67c4b3 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -333,17 +333,25 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
func.return
}
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
func.return
}
>From d0932cc2935840b7e86dc700f0ec056ccabebdea Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 11:13:57 -0500
Subject: [PATCH 03/43] [mlir][amdgpu] Convert scaled_ext_packed816 to rocdl
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 28 ++++++++++++++++---
1 file changed, 24 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..48c1b17a2203a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1492,6 +1492,19 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct ScaledExtPacked816OpLowering final
+ : public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
+ ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1600,6 +1613,12 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+ ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ return failure();
+}
+
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -2138,9 +2157,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering, TransposeLoadOpLowering,
+ AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
>From 0d1d668762b534c90f2f838788b495c09b872d81 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 15:43:41 -0500
Subject: [PATCH 04/43] Create skeleton for pattern
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 120 ++++++++++++++++++
1 file changed, 120 insertions(+)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 48c1b17a2203a..568013bee5ec8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1613,9 +1613,129 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
+                int firstScaleByte) {
+  // Packs the amdgpu.scaled_ext_packed816 attributes (blockSize, source
+  // element width, firstScaleLane, firstScaleByte) into the single scaleSel
+  // immediate taken by the rocdl.cvt.scale.pk*.f*.f* operations.
+  //
+  // Layout produced below, least significant bit first:
+  //   4/6-bit sources: [blockSize == 16][firstScaleByte == 2][firstScaleLane]
+  //   8-bit sources:   [blockSize == 16][firstScaleByte, 2 bits][firstScaleLane]
+  //
+  // Any combination outside the accepted sets is reported through
+  // llvm_unreachable.
+  //
+  // NOTE(review): for 4/6-bit sources only firstScaleByte 0 and 2 are
+  // accepted here, whereas the op verifier in this series allows
+  // firstScaleByte == 1 when blockSize == 16 -- confirm the intended
+  // encoding for that case.
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained({4, 6, 8}, bitWidth));
+
+  const int block16Bit = blockSize == 16 ? 1 : 0;
+  const bool laneInRange = firstScaleLane == 0 || firstScaleLane == 1;
+
+  if (bitWidth != 8) {
+    // For byte 0 or 2, `firstScaleByte & 2` is already the value of bit 1.
+    if (laneInRange && (firstScaleByte == 0 || firstScaleByte == 2))
+      return (firstScaleLane << 2) | (firstScaleByte & 2) | block16Bit;
+  } else if (laneInRange && firstScaleByte >= 0 && firstScaleByte <= 3) {
+    const int packed =
+        (firstScaleLane << 3) | (firstScaleByte << 1) | block16Bit;
+    // Only these nine encodings are accepted for 8-bit sources.
+    switch (packed) {
+    case 0b0000:
+    case 0b0001:
+    case 0b0010:
+    case 0b0100:
+    case 0b0110:
+    case 0b1010:
+    case 0b1100:
+    case 0b1101:
+    case 0b1110:
+      return packed;
+    default:
+      break;
+    }
+  }
+
+  llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
+                   "blockSize and type.");
+  return 0;
+}
+
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+
+ int firstScaleLane = op.getFirstScaleLane();
+ int firstScaleByte = op.getFirstScaleByte();
+ int blockSize = op.getBlockSize();
+ auto sourceType = cast<VectorType>(op.getSource().getType());
+ auto srcElemType = cast<FloatType>(sourceType.getElementType());
+ int bitWidth = srcElemType.getWidth();
+ int scaleSel =
+ getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+ auto targetType = cast<VectorType>(op.getResult().getType());
+ auto tgtElemType = cast<FloatType>(targetType.getElementType());
+ Location loc = op.getLoc();
+ // Ok, so we need to construct ops depending on the sourceType and targetType.
+ // smallT = [Fp4, Fp8, Bf8]
+ // largeT = [F16, Bf16, F32]
+ // CvtPkScalePk{8}${largeT}${smallT}
+ if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ ROCDL::CvtPkScalePk8F16Fp4Op::create(
+ rewriter, loc, op.getResult().getType(), adaptor.getSource(),
+ adaptor.getScale(), scaleSel);
+ return success();
+ }
+ /*
+ CvtPkScalePk8F16Fp8Op
+ CvtPkScalePk8F16Bf8Op
+
+ CvtPkScalePk8Bf16Fp4Op
+ CvtPkScalePk8Bf16Fp8Op
+ CvtPkScalePk8Bf16Bf8Op
+
+ CvtPkScalePk8F32Fp4Op
+ CvtPkScalePk8F32Fp8Op
+ CvtPkScalePk8F32Bf8Op
+
+ // smallT = [Fp6, Bf6]
+ // largeT = [F16, Bf16, F32]
+ // CvtPkScalePk{16}${largeT}${smallT}
+ CvtPkScale16F16Fp6Op
+ CvtPkScale16F16Bf6Op
+
+ CvtPkScale16Bf16Fp6Op
+ CvtPkScale16Bf16Bf6Op
+
+ CvtPkScale16F32Fp6Op
+ CvtPkScale16F32Bf6Op
+ */
+
return failure();
}
>From 163b15aef1795ed5d30e627a7f781406b00a37d2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 16:15:34 -0500
Subject: [PATCH 05/43] Initial conversion
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 568013bee5ec8..d93332a3f3c40 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1701,14 +1701,24 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
auto targetType = cast<VectorType>(op.getResult().getType());
auto tgtElemType = cast<FloatType>(targetType.getElementType());
Location loc = op.getLoc();
+ // %scale: vector<4xf8E8M0FNU>
+ // ===========================
+ // %scale: i32
+ IntegerType i32 = rewriter.getI32Type();
+ Value castedScale =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
// Ok, so we need to construct ops depending on the sourceType and targetType.
// smallT = [Fp4, Fp8, Bf8]
// largeT = [F16, Bf16, F32]
// CvtPkScalePk{8}${largeT}${smallT}
if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
- ROCDL::CvtPkScalePk8F16Fp4Op::create(
- rewriter, loc, op.getResult().getType(), adaptor.getSource(),
- adaptor.getScale(), scaleSel);
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
return success();
}
/*
>From 6d7e2a65171ae253430067e3db23b94ca6a29c27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 16:26:39 -0500
Subject: [PATCH 06/43] Add first test
---
.../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 2fd3df6dcfa71..840187d1f36d7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -456,3 +456,16 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}
+
+// CHECK-LABEL: @scaled_ext_packed816_fp4
+// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return %ret0: vector<8xf16>
+}
+
>From fc5d8587ac8f3ea9142d3810d1dfea7ae42b4ff1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 17:18:58 -0500
Subject: [PATCH 07/43] Adds two more cases
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 28 +++++++++++++++++++
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 25 ++++++++++++++++-
2 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d93332a3f3c40..605bcf38204ba 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,6 +1721,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // vector<8xf8E4M3FN>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+ if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // vector<8xf8E5M2>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
CvtPkScalePk8F16Fp8Op
CvtPkScalePk8F16Bf8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 840187d1f36d7..8edd6c038af1e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -458,7 +458,7 @@ func.func @sched_barrier() {
}
// CHECK-LABEL: @scaled_ext_packed816_fp4
-// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
// CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
@@ -469,3 +469,26 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
func.return %ret0: vector<8xf16>
}
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return %ret0 : vector<8xf16>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return %ret0 : vector<8xf16>
+}
>From d9a254f629fe430c961ef407e2d529a1a10f691a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:14:28 -0500
Subject: [PATCH 08/43] Add case for pk8.bf16.fp4
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 18 +++++++++++++++---
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 11 ++++++++---
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 605bcf38204ba..5fd58c2b2906b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1713,6 +1713,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// largeT = [F16, Bf16, F32]
// CvtPkScalePk{8}${largeT}${smallT}
if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // CvtPkScalePk8F16Fp4Op
+ // i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
@@ -1722,6 +1724,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
return success();
}
if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // CvtPkScalePk8F16Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1736,6 +1739,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
return success();
}
if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // CvtPkScalePk8F16Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1749,11 +1753,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float4E2M1FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ // CvtPkScalePk8Bf16Fp4Op
+ // i32
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScalePk8F16Fp8Op
- CvtPkScalePk8F16Bf8Op
- CvtPkScalePk8Bf16Fp4Op
CvtPkScalePk8Bf16Fp8Op
CvtPkScalePk8Bf16Bf8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8edd6c038af1e..d2099d2f60eff 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -459,14 +459,19 @@ func.func @sched_barrier() {
// CHECK-LABEL: @scaled_ext_packed816_fp4
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
// CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
// CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
- // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return %ret0: vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+ func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
}
// CHECK-LABEL: @scaled_ext_packed816_fp8
>From c5eb6989affbd32c2100909943e790af851a5a43 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:28:50 -0500
Subject: [PATCH 09/43] Add conversion for pk8.bf16.bf8
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++-
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 31 ++++++++++++-------
2 files changed, 34 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5fd58c2b2906b..ede420606b7b4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1764,9 +1764,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ // CvtPkScalePk8Bf16Fp8Op
+ // vector<8xf8E5M2>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8Bf16Fp8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScalePk8Bf16Fp8Op
CvtPkScalePk8Bf16Bf8Op
CvtPkScalePk8F32Fp4Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index d2099d2f60eff..e6b2ef7bc3f79 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -474,6 +474,23 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
}
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+ func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+}
+
// CHECK-LABEL: @scaled_ext_packed816_fp8
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
@@ -481,19 +498,9 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
// CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
// CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return %ret0 : vector<8xf16>
-}
-// CHECK-LABEL: @scaled_ext_packed816_bf8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
- // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
func.return %ret0 : vector<8xf16>
}
+
>From cec5f045c639adeb2fda944751768b69fd9a5b9d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:41:25 -0500
Subject: [PATCH 10/43] Fix and add new case
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 24 ++++++++++++--
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 32 +++++++++++--------
2 files changed, 40 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ede420606b7b4..8a55f41e8495c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1710,8 +1710,12 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// Ok, so we need to construct ops depending on the sourceType and targetType.
// smallT = [Fp4, Fp8, Bf8]
+ // Bf8 = E5M2
+ // Fp8 = E4M3
+ //
// largeT = [F16, Bf16, F32]
// CvtPkScalePk{8}${largeT}${smallT}
+
if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Fp4Op
// i32
@@ -1764,9 +1768,9 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
- if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ if (isa<Float8E4M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Fp8Op
- // vector<8xf8E5M2>
+ // vector<8xf8E4M3FN>
Value source = adaptor.getSource();
// vector<2xi32>
@@ -1779,9 +1783,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ // CvtPkScalePk8Bf16Bf8Op
+ // vector<8xf8E5M2>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScalePk8Bf16Bf8Op
CvtPkScalePk8F32Fp4Op
CvtPkScalePk8F32Fp8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index e6b2ef7bc3f79..e248856ed472e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -474,6 +474,24 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
}
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+}
+
// CHECK-LABEL: @scaled_ext_packed816_bf8
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
@@ -486,21 +504,9 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
// CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
}
-// CHECK-LABEL: @scaled_ext_packed816_fp8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
- // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
- func.return %ret0 : vector<8xf16>
-}
>From 7dc34425abb1ff3271b039cbd95d008531c6a940 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:45:32 -0500
Subject: [PATCH 11/43] Add case for pk8.f32.fp4
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +++++++++++--
.../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 18 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8a55f41e8495c..5c5aad91cd8e0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1798,9 +1798,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float4E2M1FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ // CvtPkScalePk8F32Fp4Op
+ // i32
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
-
-
CvtPkScalePk8F32Fp4Op
CvtPkScalePk8F32Fp8Op
CvtPkScalePk8F32Bf8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index e248856ed472e..13f109b787d4d 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -459,7 +459,7 @@ func.func @sched_barrier() {
// CHECK-LABEL: @scaled_ext_packed816_fp4
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
// CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -471,7 +471,12 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
// CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
// CHECK-LABEL: @scaled_ext_packed816_fp8
>From 0ba6b949e626bbafdb5fdc5b94cea342b4e7aab5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:50:01 -0500
Subject: [PATCH 12/43] Add case for pk8.f32.fp8
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 17 +++++++++++++++--
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5c5aad91cd8e0..a347a822aba66 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1809,9 +1809,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float8E4M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ // CvtPkScalePk8F32Fp8Op
+ // vector<8xf8E4M3FN>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScalePk8F32Fp4Op
- CvtPkScalePk8F32Fp8Op
CvtPkScalePk8F32Bf8Op
// smallT = [Fp6, Bf6]
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 13f109b787d4d..917d25b9bc8f2 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -481,7 +481,7 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
// CHECK-LABEL: @scaled_ext_packed816_fp8
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
// CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -494,7 +494,12 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
// CHECK-LABEL: @scaled_ext_packed816_bf8
>From 551849e8496343a293005c32e9f3c6cc6429eba6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:54:18 -0500
Subject: [PATCH 13/43] Add case for pk8.f32.bf8
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++++++-
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 22 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a347a822aba66..72617f76e6c91 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1824,8 +1824,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float8E5M2Type>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ // CvtPkScalePk8F32Bf8Op
+ // vector<8xf8E5M2>
+ Value source = adaptor.getSource();
+
+ // vector<2xi32>
+ VectorType v2xi32 = VectorType::get(2, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScalePk8F32Bf8Op
// smallT = [Fp6, Bf6]
// largeT = [F16, Bf16, F32]
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 917d25b9bc8f2..220759b6fab39 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -504,7 +504,7 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
// CHECK-LABEL: @scaled_ext_packed816_bf8
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
// CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
// CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -516,7 +516,12 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
// CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
// CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
>From c1e10c8e4b9b6a6ba9ce210aabc7bc97a5608cab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:26:48 -0500
Subject: [PATCH 14/43] Add case for pk16.f16.bf6
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 38 ++++++++++++++++---
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 26 +++++++++++++
2 files changed, 58 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 72617f76e6c91..af068fbfd957f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1714,7 +1714,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// Fp8 = E4M3
//
// largeT = [F16, Bf16, F32]
- // CvtPkScalePk{8}${largeT}${smallT}
+ // CvtPkScalePk8${largeT}${smallT}
if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Fp4Op
@@ -1839,13 +1839,39 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
- /*
-
// smallT = [Fp6, Bf6]
// largeT = [F16, Bf16, F32]
- // CvtPkScalePk{16}${largeT}${smallT}
- CvtPkScale16F16Fp6Op
- CvtPkScale16F16Bf6Op
+ //
+ // Fp6 = Float6E2M3FN
+ // Bf6 = Float6E3M2FN
+
+ // CvtPkScalePk16${largeT}${smallT}
+ if (isa<Float6E2M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // CvtPkScalePk16F16Fp6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+ if (isa<Float6E3M2FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ // CvtPkScalePk16F16Bf6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+
+ /*
CvtPkScale16Bf16Fp6Op
CvtPkScale16Bf16Bf6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 220759b6fab39..349440053a646 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -525,3 +525,29 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
}
+// CHECK-LABEL: @scaled_ext_packed816_fp6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ return %ret0: vector<16xf16>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ return %ret0: vector<16xf16>
+}
+
>From e958d56ca7e5fa32767a10186fe17905adf11ecf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:30:17 -0500
Subject: [PATCH 15/43] Add case for pk16.bf16.fp6
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
.../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index af068fbfd957f..cbe3251268f78 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1870,10 +1870,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float6E2M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ // CvtPkScalePk16Bf16Fp6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScale16Bf16Fp6Op
CvtPkScale16Bf16Bf6Op
CvtPkScale16F32Fp6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 349440053a646..76e26e0cab40a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
// CHECK-LABEL: @scaled_ext_packed816_fp6
// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
// CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
@@ -535,7 +535,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
// CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
// CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
- return %ret0: vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+ return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
}
// CHECK-LABEL: @scaled_ext_packed816_bf6
>From 1f79bdd99637a7be7765df808c3e3a63199ca524 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:38:08 -0500
Subject: [PATCH 16/43] Add case for pk16.bf16.bf6
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
.../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index cbe3251268f78..322c8efcf3aea 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1882,10 +1882,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float6E3M2FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ // CvtPkScalePk16Bf16Bf6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScale16Bf16Bf6Op
CvtPkScale16F32Fp6Op
CvtPkScale16F32Bf6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 76e26e0cab40a..b31116538228e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -545,7 +545,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
// CHECK-LABEL: @scaled_ext_packed816_bf6
// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
// CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
@@ -553,6 +553,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
// CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
// CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
- return %ret0: vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+ return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
}
>From c5628e6ecf7a101a247da33fb55a497cfe33e98c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:44:19 -0500
Subject: [PATCH 17/43] Add case for pk16.f32.fp6
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
.../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 322c8efcf3aea..fb189b24b29e7 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1894,11 +1894,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float6E2M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ // CvtPkScalePk16F32Fp6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+ auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
/*
- CvtPkScale16F32Fp6Op
CvtPkScale16F32Bf6Op
*/
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index b31116538228e..dccdb81033738 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
// CHECK-LABEL: @scaled_ext_packed816_fp6
// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
// CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
@@ -540,7 +540,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
// CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
// CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
- return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
}
// CHECK-LABEL: @scaled_ext_packed816_bf6
>From db56c98bfa77ab923d4cfed8ddbb4ed15863b690 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:48:29 -0500
Subject: [PATCH 18/43] Add case for pk16.f32.bf6
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++-----
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++--
2 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index fb189b24b29e7..26282ec4c2279 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1906,12 +1906,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOp(op, newOp);
return success();
}
+ if (isa<Float6E3M2FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ // CvtPkScale16F32Bf6Op
+ Value source = adaptor.getSource();
+ VectorType v3xi32 = VectorType::get(3, i32);
+ Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- /*
-
-
- CvtPkScale16F32Bf6Op
- */
+ auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create(
+ rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+ scaleSel);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
return failure();
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index dccdb81033738..94a04d98004c7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -550,7 +550,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
// CHECK-LABEL: @scaled_ext_packed816_bf6
// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
// CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
// CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
@@ -563,6 +563,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
// CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
// CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
- return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
}
>From 73ed4b7b7d1562a0aa50fa7caec5c15426c57c27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:52:19 -0500
Subject: [PATCH 19/43] Refactor NFC
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 74 ++++++++-----------
1 file changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 26282ec4c2279..44393033ef442 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1725,9 +1725,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and
+ isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1740,9 +1739,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and
+ isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1755,9 +1753,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float4E2M1FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) and
+ isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Fp4Op
// i32
Value castedSource =
@@ -1766,9 +1763,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E4M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and
+ isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1781,9 +1777,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and
+ isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1796,9 +1791,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float4E2M1FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) and
+ isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Fp4Op
// i32
Value castedSource =
@@ -1807,9 +1801,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E4M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and
+ isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1822,9 +1815,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float8E5M2Type>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and
+ isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
}
// smallT = [Fp6, Bf6]
// largeT = [F16, Bf16, F32]
@@ -1846,7 +1837,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// Bf6 = Float6E3M2FN
// CvtPkScalePk16${largeT}${smallT}
- if (isa<Float6E2M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ else if (isa<Float6E2M3FNType>(srcElemType) and
+ isa<Float16Type>(tgtElemType)) {
// CvtPkScale16F16Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1856,9 +1848,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float6E3M2FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and
+ isa<Float16Type>(tgtElemType)) {
// CvtPkScale16F16Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1868,9 +1859,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float6E2M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) and
+ isa<BFloat16Type>(tgtElemType)) {
// CvtPkScale16Bf16Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1880,9 +1870,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float6E3M2FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and
+ isa<BFloat16Type>(tgtElemType)) {
// CvtPkScale16Bf16Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1892,9 +1881,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float6E2M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) and
+ isa<Float32Type>(tgtElemType)) {
// CvtPkScale16F32Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1904,9 +1892,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
- }
- if (isa<Float6E3M2FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and
+ isa<Float32Type>(tgtElemType)) {
// CvtPkScale16F32Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1916,10 +1903,11 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter, loc, op.getResult().getType(), castedSource, castedScale,
scaleSel);
rewriter.replaceOp(op, newOp);
- return success();
+ } else {
+ return failure();
}
- return failure();
+ return success();
}
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
>From 94cc74070c0d91382c3444d96e837eb17f6afc52 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 11:01:43 -0500
Subject: [PATCH 20/43] Refactor NFC
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 90 +++++++------------
1 file changed, 30 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 44393033ef442..e28b53d8ed1a5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,10 +1721,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
- auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and
isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Fp8Op
@@ -1735,10 +1733,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and
isa<Float16Type>(tgtElemType)) {
// CvtPkScalePk8F16Bf8Op
@@ -1749,20 +1745,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and
isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Fp4Op
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
- auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and
isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Fp8Op
@@ -1773,10 +1765,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8Bf16Fp8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and
isa<BFloat16Type>(tgtElemType)) {
// CvtPkScalePk8Bf16Bf8Op
@@ -1787,20 +1777,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and
isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Fp4Op
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
- auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and
isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Fp8Op
@@ -1811,10 +1797,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and
isa<Float32Type>(tgtElemType)) {
// CvtPkScalePk8F32Bf8Op
@@ -1825,10 +1809,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v2xi32 = VectorType::get(2, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
- auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
// smallT = [Fp6, Bf6]
// largeT = [F16, Bf16, F32]
@@ -1844,10 +1826,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and
isa<Float16Type>(tgtElemType)) {
// CvtPkScale16F16Bf6Op
@@ -1855,10 +1835,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and
isa<BFloat16Type>(tgtElemType)) {
// CvtPkScale16Bf16Fp6Op
@@ -1866,10 +1844,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and
isa<BFloat16Type>(tgtElemType)) {
// CvtPkScale16Bf16Bf6Op
@@ -1877,10 +1853,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and
isa<Float32Type>(tgtElemType)) {
// CvtPkScale16F32Fp6Op
@@ -1888,10 +1862,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and
isa<Float32Type>(tgtElemType)) {
// CvtPkScale16F32Bf6Op
@@ -1899,10 +1871,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
- auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create(
- rewriter, loc, op.getResult().getType(), castedSource, castedScale,
- scaleSel);
- rewriter.replaceOp(op, newOp);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
+ op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
return failure();
}
>From 4f27e043e39776b908e0ee4e079fd16fc7ef2f3a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:45:30 -0500
Subject: [PATCH 21/43] Use method instead of isa
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 46 +++++++------------
1 file changed, 16 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index e28b53d8ed1a5..53f7accdb5a54 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1699,7 +1699,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
auto targetType = cast<VectorType>(op.getResult().getType());
- auto tgtElemType = cast<FloatType>(targetType.getElementType());
+ auto destElemType = cast<FloatType>(targetType.getElementType());
Location loc = op.getLoc();
// %scale: vector<4xf8E8M0FNU>
// ===========================
@@ -1716,15 +1716,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// largeT = [F16, Bf16, F32]
// CvtPkScalePk8${largeT}${smallT}
- if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+ if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp4Op
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and
- isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1735,8 +1734,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and
- isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1747,16 +1745,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) and
- isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Fp4Op
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and
- isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1767,8 +1763,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and
- isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1779,16 +1774,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) and
- isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Fp4Op
// i32
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and
- isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Fp8Op
// vector<8xf8E4M3FN>
Value source = adaptor.getSource();
@@ -1799,8 +1792,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and
- isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Bf8Op
// vector<8xf8E5M2>
Value source = adaptor.getSource();
@@ -1819,8 +1811,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// Bf6 = Float6E3M2FN
// CvtPkScalePk16${largeT}${smallT}
- else if (isa<Float6E2M3FNType>(srcElemType) and
- isa<Float16Type>(tgtElemType)) {
+ else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1828,8 +1819,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and
- isa<Float16Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1837,8 +1827,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) and
- isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1846,8 +1835,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and
- isa<BFloat16Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1855,8 +1843,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) and
- isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Fp6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
@@ -1864,8 +1851,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and
- isa<Float32Type>(tgtElemType)) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Bf6Op
Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
>From 0f8f3c493817873ca8047b7f2b788d1845716f09 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:47:59 -0500
Subject: [PATCH 22/43] Hoist variable
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +------------
1 file changed, 1 insertion(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 53f7accdb5a54..646d27a164830 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1715,6 +1715,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
//
// largeT = [F16, Bf16, F32]
// CvtPkScalePk8${largeT}${smallT}
+ Value source = adaptor.getSource();
if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp4Op
@@ -1726,7 +1727,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp8Op
// vector<8xf8E4M3FN>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1737,7 +1737,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Bf8Op
// vector<8xf8E5M2>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1755,7 +1754,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Fp8Op
// vector<8xf8E4M3FN>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1766,7 +1764,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Bf8Op
// vector<8xf8E5M2>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1784,7 +1781,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Fp8Op
// vector<8xf8E4M3FN>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1795,7 +1791,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Bf8Op
// vector<8xf8E5M2>
- Value source = adaptor.getSource();
// vector<2xi32>
VectorType v2xi32 = VectorType::get(2, i32);
@@ -1813,7 +1808,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// CvtPkScalePk16${largeT}${smallT}
else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Fp6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
@@ -1821,7 +1815,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Bf6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
@@ -1829,7 +1822,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Fp6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
@@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Bf6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
@@ -1845,7 +1836,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Fp6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
@@ -1853,7 +1843,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Bf6Op
- Value source = adaptor.getSource();
VectorType v3xi32 = VectorType::get(3, i32);
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
>From a7a853ea525814ef1f7707cdb99a6327fcbbbfa1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:59:31 -0500
Subject: [PATCH 23/43] Refactor NFC
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 92 ++++++++-----------
1 file changed, 40 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 646d27a164830..824e4249088ae 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1708,6 +1708,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+ Value source = adaptor.getSource();
+ Type packedType;
+ if (isa<Float4E2M1FNType>(srcElemType)) {
+ packedType = i32;
+ } else if (isa<Float8E4M3FNType>(srcElemType) ||
+ isa<Float8E5M2Type>(srcElemType)) {
+ packedType = VectorType::get(2, i32);
+ } else if (isa<Float6E2M3FNType>(srcElemType) ||
+ isa<Float6E3M2FNType>(srcElemType)) {
+ packedType = VectorType::get(3, i32);
+ } else {
+ llvm_unreachable("invalid element type for scaled ext");
+ }
// Ok, so we need to construct ops depending on the sourceType and targetType.
// smallT = [Fp4, Fp8, Bf8]
// Bf8 = E5M2
@@ -1715,87 +1728,68 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
//
// largeT = [F16, Bf16, F32]
// CvtPkScalePk8${largeT}${smallT}
- Value source = adaptor.getSource();
if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp4Op
// i32
Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Fp8Op
// vector<8xf8E4M3FN>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
// CvtPkScalePk8F16Bf8Op
// vector<8xf8E5M2>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Fp4Op
// i32
Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Fp8Op
// vector<8xf8E4M3FN>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
// CvtPkScalePk8Bf16Bf8Op
// vector<8xf8E5M2>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Fp4Op
// i32
Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Fp8Op
// vector<8xf8E4M3FN>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
// CvtPkScalePk8F32Bf8Op
// vector<8xf8E5M2>
-
- // vector<2xi32>
- VectorType v2xi32 = VectorType::get(2, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
@@ -1808,44 +1802,38 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// CvtPkScalePk16${largeT}${smallT}
else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Fp6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
// CvtPkScale16F16Bf6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Fp6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
// CvtPkScale16Bf16Bf6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Fp6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
// CvtPkScale16F32Bf6Op
- VectorType v3xi32 = VectorType::get(3, i32);
- Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
>From 9e4ab0e7b03009fe50cddaf8f218e93fe0bc82f1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:02:38 -0500
Subject: [PATCH 24/43] Hoist variable. NFC
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 56 +------------------
1 file changed, 2 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 824e4249088ae..3d41d47da6e00 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1728,68 +1728,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
//
// largeT = [F16, Bf16, F32]
// CvtPkScalePk8${largeT}${smallT}
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
- // CvtPkScalePk8F16Fp4Op
- // i32
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
- // CvtPkScalePk8F16Fp8Op
- // vector<8xf8E4M3FN>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
- // CvtPkScalePk8F16Bf8Op
- // vector<8xf8E5M2>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
- // CvtPkScalePk8Bf16Fp4Op
- // i32
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
- // CvtPkScalePk8Bf16Fp8Op
- // vector<8xf8E4M3FN>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
- // CvtPkScalePk8Bf16Bf8Op
- // vector<8xf8E5M2>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
- // CvtPkScalePk8F32Fp4Op
- // i32
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
- // CvtPkScalePk8F32Fp8Op
- // vector<8xf8E4M3FN>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
- // CvtPkScalePk8F32Bf8Op
- // vector<8xf8E5M2>
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
@@ -1801,39 +1767,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// CvtPkScalePk16${largeT}${smallT}
else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
- // CvtPkScale16F16Fp6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
- // CvtPkScale16F16Bf6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
- // CvtPkScale16Bf16Fp6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
- // CvtPkScale16Bf16Bf6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
- // CvtPkScale16F32Fp6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
- // CvtPkScale16F32Bf6Op
- Value castedSource =
- LLVM::BitcastOp::create(rewriter, loc, packedType, source);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
>From 728686f4ab79a5f64c389c83fc47447f0b69b663 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:03:55 -0500
Subject: [PATCH 25/43] Comments. NFC
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3d41d47da6e00..0affa2ead9f78 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,7 +1721,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
} else {
llvm_unreachable("invalid element type for scaled ext");
}
- // Ok, so we need to construct ops depending on the sourceType and targetType.
// smallT = [Fp4, Fp8, Bf8]
// Bf8 = E5M2
// Fp8 = E4M3
@@ -1760,11 +1759,10 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
// smallT = [Fp6, Bf6]
+ // Fp6 = Float6E2M3FN
+ // Bf6 = Float6E3M2FN
// largeT = [F16, Bf16, F32]
//
- // Fp6 = Float6E2M3FN
- // Bf6 = Float6E3M2FN
-
// CvtPkScalePk16${largeT}${smallT}
else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
>From 47dc32e6698b7ea8f3926bc39d97273b32a012df Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:09:01 -0500
Subject: [PATCH 26/43] refactor. nfc
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 33 +++++++++----------
1 file changed, 15 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0affa2ead9f78..ec889879fd8d4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1701,9 +1701,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
auto targetType = cast<VectorType>(op.getResult().getType());
auto destElemType = cast<FloatType>(targetType.getElementType());
Location loc = op.getLoc();
- // %scale: vector<4xf8E8M0FNU>
- // ===========================
- // %scale: i32
IntegerType i32 = rewriter.getI32Type();
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
@@ -1730,31 +1727,31 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
- if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
+ if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
+ } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
+ } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
+ } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
+ } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
+ } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
@@ -1764,22 +1761,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
// largeT = [F16, Bf16, F32]
//
// CvtPkScalePk16${largeT}${smallT}
- else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
+ else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
+ } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
+ } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
>From 3cfea7e2abf631970b8a115c3ed1f7d56267ad38 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:14:44 -0500
Subject: [PATCH 27/43] Keep conventions
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ec889879fd8d4..6dbf57342cb3f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1709,12 +1709,15 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
Type packedType;
if (isa<Float4E2M1FNType>(srcElemType)) {
packedType = i32;
+ packedType = getTypeConverter()->convertType(packedType);
} else if (isa<Float8E4M3FNType>(srcElemType) ||
isa<Float8E5M2Type>(srcElemType)) {
packedType = VectorType::get(2, i32);
+ packedType = getTypeConverter()->convertType(packedType);
} else if (isa<Float6E2M3FNType>(srcElemType) ||
isa<Float6E3M2FNType>(srcElemType)) {
packedType = VectorType::get(3, i32);
+ packedType = getTypeConverter()->convertType(packedType);
} else {
llvm_unreachable("invalid element type for scaled ext");
}
>From 2b010cd04d4644b06a864777de0c63edae4be0c6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:35:11 -0500
Subject: [PATCH 28/43] Less of exhaustive enumeration
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 71 +++++--------------
1 file changed, 19 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6dbf57342cb3f..a9f58063ac32b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1627,62 +1627,29 @@ int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;
- if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
- return 0b000;
- }
- if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
- return 0b001;
- }
- if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
- return 0b010;
- }
- if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) {
- return 0b011;
- }
- if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) {
- return 0b100;
- }
- if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) {
- return 0b101;
- }
- if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
- return 0b110;
- }
- if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
- return 0b111;
- }
-
- if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
- return 0b0000;
- }
- if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
- return 0b0001;
- }
- if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) {
- return 0b0010;
- }
- if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
- return 0b0100;
- }
- if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) {
- return 0b0110;
- }
- if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) {
- return 0b1010;
- }
- if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
- return 0b1100;
- }
- if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
- return 0b1101;
- }
- if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) {
- return 0b1110;
+ if (!is_fp8) {
+ int bit_0 = is_block_16;
+ assert(llvm::is_contained({0, 2}, firstScaleByte));
+ int bit_1 = (firstScaleByte == 2) << 1;
+ assert(llvm::is_contained({0, 1}, firstScaleLane));
+ int bit_2 = firstScaleLane << 2;
+ return bit_2 | bit_1 | bit_0;
+ } else {
+ int bit_0 = is_block_16;
+ // firstScaleByte is guaranteed to be defined by two bits.
+ assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+ int bit_2_and_1 = firstScaleByte << 1;
+ assert(llvm::is_contained({0, 1}, firstScaleLane));
+ int bit_3 = firstScaleLane << 3;
+ int bits = bit_3 | bit_2_and_1 | bit_0;
+ // These are invalid cases.
+ assert(!llvm::is_contained(
+ {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+ return bits;
}
llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
"blockSize and type.");
- return 0;
}
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
>From 6f07ef03b6f778fec82abc85914c6b47095d2312 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:39:10 -0500
Subject: [PATCH 29/43] Correct types
---
.../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a9f58063ac32b..596cac1b76469 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1613,8 +1613,8 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
-int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
- int firstScaleByte) {
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+ int32_t firstScaleLane, int32_t firstScaleByte) {
// When lowering amdgpu.scaled_ext_packed816 to
// rocdl.cvt.scale.pk*.f*.f* operations, the
// attributes blockSize, sourceType, firstScaleLane and firstScaleByte
@@ -1656,13 +1656,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- int firstScaleLane = op.getFirstScaleLane();
- int firstScaleByte = op.getFirstScaleByte();
- int blockSize = op.getBlockSize();
+ int32_t firstScaleLane = op.getFirstScaleLane();
+ int32_t firstScaleByte = op.getFirstScaleByte();
+ int32_t blockSize = op.getBlockSize();
auto sourceType = cast<VectorType>(op.getSource().getType());
auto srcElemType = cast<FloatType>(sourceType.getElementType());
- int bitWidth = srcElemType.getWidth();
- int scaleSel =
+ unsigned bitWidth = srcElemType.getWidth();
+ int32_t scaleSel =
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
auto targetType = cast<VectorType>(op.getResult().getType());
>From 69787933cc44ed01faff7abf38078c0f61917bbd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:39:48 -0500
Subject: [PATCH 30/43] Reflow comment
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 596cac1b76469..a7c9a68dd7731 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1615,12 +1615,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
int32_t firstScaleLane, int32_t firstScaleByte) {
- // When lowering amdgpu.scaled_ext_packed816 to
- // rocdl.cvt.scale.pk*.f*.f* operations, the
- // attributes blockSize, sourceType, firstScaleLane and firstScaleByte
- // are merged into a single attribute scaleSel.
- //
- // This is how those values are merged together.
+ // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
+ // operations, the attributes blockSize, sourceType, firstScaleLane and
+ // firstScaleByte are merged into a single attribute scaleSel. This is how
+ // those values are merged together.
assert(llvm::is_contained({16, 32}, blockSize));
assert(llvm::is_contained({4, 6, 8}, bitWidth));
>From ed66571310d9fd687e1b549beb5a79704c629cb4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:40:29 -0500
Subject: [PATCH 31/43] superfluous empty line
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a7c9a68dd7731..15c511460a552 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1653,7 +1653,6 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
-
int32_t firstScaleLane = op.getFirstScaleLane();
int32_t firstScaleByte = op.getFirstScaleByte();
int32_t blockSize = op.getBlockSize();
>From 33ef57e0dce2640dc8c3cf3c5623ffc71eb42d18 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:55:53 -0500
Subject: [PATCH 32/43] Using
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 58 ++++++++-----------
1 file changed, 24 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 15c511460a552..b0b0d4a9fe604 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1620,7 +1620,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
// firstScaleByte are merged into a single attribute scaleSel. This is how
// those values are merged together.
assert(llvm::is_contained({16, 32}, blockSize));
- assert(llvm::is_contained({4, 6, 8}, bitWidth));
+ assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;
@@ -1653,6 +1653,11 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
int32_t firstScaleLane = op.getFirstScaleLane();
int32_t firstScaleByte = op.getFirstScaleByte();
int32_t blockSize = op.getBlockSize();
@@ -1671,79 +1676,64 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
Value source = adaptor.getSource();
Type packedType;
- if (isa<Float4E2M1FNType>(srcElemType)) {
+ if (isa<fp4>(srcElemType)) {
packedType = i32;
packedType = getTypeConverter()->convertType(packedType);
- } else if (isa<Float8E4M3FNType>(srcElemType) ||
- isa<Float8E5M2Type>(srcElemType)) {
+ } else if (isa<fp8, bf8>(srcElemType)) {
packedType = VectorType::get(2, i32);
packedType = getTypeConverter()->convertType(packedType);
- } else if (isa<Float6E2M3FNType>(srcElemType) ||
- isa<Float6E3M2FNType>(srcElemType)) {
+ } else if (isa<fp6, bf6>(srcElemType)) {
packedType = VectorType::get(3, i32);
packedType = getTypeConverter()->convertType(packedType);
} else {
llvm_unreachable("invalid element type for scaled ext");
}
- // smallT = [Fp4, Fp8, Bf8]
- // Bf8 = E5M2
- // Fp8 = E4M3
- //
- // largeT = [F16, Bf16, F32]
- // CvtPkScalePk8${largeT}${smallT}
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
- if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
+ if (isa<fp4>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
+ } else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
+ } else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
+ } else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
+ } else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
+ } else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
+ } else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
+ } else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
+ } else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- }
- // smallT = [Fp6, Bf6]
- // Fp6 = Float6E2M3FN
- // Bf6 = Float6E3M2FN
- // largeT = [F16, Bf16, F32]
- //
- // CvtPkScalePk16${largeT}${smallT}
- else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
+ } else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
+ } else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
+ } else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
+ } else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
+ } else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
+ } else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
>From a83cec94451eae1889c92e333dd6fa7b47904bad Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:18:27 -0500
Subject: [PATCH 33/43] Add chipset check and moved tests
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 8 +-
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 114 -----------------
.../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 116 ++++++++++++++++++
3 files changed, 123 insertions(+), 115 deletions(-)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b0b0d4a9fe604..7e30feb520a72 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1658,6 +1658,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
+ Location loc = op.getLoc();
+ if (chipset != Chipset{12, 5, 0}) {
+ return rewriter.notifyMatchFailure(
+ loc,
+ "Scaled fp packed conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ }
int32_t firstScaleLane = op.getFirstScaleLane();
int32_t firstScaleByte = op.getFirstScaleByte();
int32_t blockSize = op.getBlockSize();
@@ -1669,7 +1676,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
auto targetType = cast<VectorType>(op.getResult().getType());
auto destElemType = cast<FloatType>(targetType.getElementType());
- Location loc = op.getLoc();
IntegerType i32 = rewriter.getI32Type();
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 94a04d98004c7..432b8876696a9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -457,117 +457,3 @@ func.func @sched_barrier() {
func.return
}
-// CHECK-LABEL: @scaled_ext_packed816_fp4
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
- // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
- // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
- // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
- func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_fp8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
-
- func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_bf8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
- // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
- func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-
-// CHECK-LABEL: @scaled_ext_packed816_fp6
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
- // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
- return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_bf6
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
- // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
- // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-
- // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
- // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
- // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
- return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
-}
-
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
new file mode 100644
index 0000000000000..811a8e49dc5c6
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+
+// CHECK-LABEL: @scaled_ext_packed816_fp4
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+
+// CHECK-LABEL: @scaled_ext_packed816_fp6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+
>From 34ed3e9384a683e44b967baff53fb952b3320e90 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:21:24 -0500
Subject: [PATCH 34/43] Refactor NFC
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 25 ++++++++-----------
1 file changed, 11 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e30feb520a72..0c5f4ff7f8227 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1632,22 +1632,19 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
assert(llvm::is_contained({0, 1}, firstScaleLane));
int bit_2 = firstScaleLane << 2;
return bit_2 | bit_1 | bit_0;
- } else {
- int bit_0 = is_block_16;
- // firstScaleByte is guaranteed to be defined by two bits.
- assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
- int bit_2_and_1 = firstScaleByte << 1;
- assert(llvm::is_contained({0, 1}, firstScaleLane));
- int bit_3 = firstScaleLane << 3;
- int bits = bit_3 | bit_2_and_1 | bit_0;
- // These are invalid cases.
- assert(!llvm::is_contained(
- {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
- return bits;
}
- llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
- "blockSize and type.");
+ int bit_0 = is_block_16;
+ // firstScaleByte is guaranteed to be defined by two bits.
+ assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+ int bit_2_and_1 = firstScaleByte << 1;
+ assert(llvm::is_contained({0, 1}, firstScaleLane));
+ int bit_3 = firstScaleLane << 3;
+ int bits = bit_3 | bit_2_and_1 | bit_0;
+ // These are invalid cases.
+ assert(!llvm::is_contained(
+ {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+ return bits;
}
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
>From b88f7f6e74c0038873dabec3d31b2bba12b8e6ab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:44:41 -0500
Subject: [PATCH 35/43] Use operation name
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 105 ++++++++++--------
1 file changed, 57 insertions(+), 48 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0c5f4ff7f8227..23660361094c3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1647,6 +1647,46 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
return bits;
}
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ if (isa<fp4>(srcElemType) && destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+ if (isa<fp8>(srcElemType) && destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+ if (isa<bf8>(srcElemType) && destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+ if (isa<fp4>(srcElemType) && destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+ if (isa<fp8>(srcElemType) && destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+ if (isa<bf8>(srcElemType) && destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+ if (isa<fp4>(srcElemType) && destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+ if (isa<fp8>(srcElemType) && destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+ if (isa<bf8>(srcElemType) && destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+ if (isa<fp6>(srcElemType) && destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+ if (isa<bf6>(srcElemType) && destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+ if (isa<fp6>(srcElemType) && destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+ if (isa<bf6>(srcElemType) && destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+ if (isa<fp6>(srcElemType) && destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+ if (isa<bf6>(srcElemType) && destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+ return std::nullopt;
+}
+
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1694,54 +1734,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
- if (isa<fp4>(srcElemType) && destElemType.isF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
- op, op.getResult().getType(), castedSource, castedScale, scaleSel);
- } else {
- return failure();
- }
+ std::optional<StringRef> maybeIntrinsic =
+ scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching packed scaled conversion on the given chipset");
+
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes({op.getResult().getType()});
+ loweredOp.addOperands({castedSource, castedScale});
+
+ SmallVector<NamedAttribute, 1> attrs;
+ attrs.push_back(
+ NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
+
+ loweredOp.addAttributes(attrs);
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered);
return success();
}
>From 7a7ecaf31b5ad267343b871d5f2f44b8eff875b3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:48:38 -0500
Subject: [PATCH 36/43] Convert result type
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 23660361094c3..2e73f7b0d4266 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1741,7 +1741,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
"no intrinsic matching packed scaled conversion on the given chipset");
OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes({op.getResult().getType()});
+ Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+ loweredOp.addTypes({llvmResultType});
loweredOp.addOperands({castedSource, castedScale});
SmallVector<NamedAttribute, 1> attrs;
>From 1025e2b999b7f42947842cfb3185b921a0ea534d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:58:20 -0500
Subject: [PATCH 37/43] Check for type conversion failures
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2e73f7b0d4266..ae907b8ffefc3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1718,7 +1718,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
Value source = adaptor.getSource();
- Type packedType;
+ Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+ Type packedType = nullptr;
if (isa<fp4>(srcElemType)) {
packedType = i32;
packedType = getTypeConverter()->convertType(packedType);
@@ -1729,8 +1730,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
packedType = VectorType::get(3, i32);
packedType = getTypeConverter()->convertType(packedType);
} else {
- llvm_unreachable("invalid element type for scaled ext");
+ llvm_unreachable("invalid element type for packed scaled ext");
+ }
+
+ if (!packedType || !llvmResultType) {
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
}
+
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
@@ -1741,7 +1747,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
"no intrinsic matching packed scaled conversion on the given chipset");
OperationState loweredOp(loc, *maybeIntrinsic);
- Type llvmResultType = typeConverter->convertType(op.getResult().getType());
loweredOp.addTypes({llvmResultType});
loweredOp.addOperands({castedSource, castedScale});
>From 7c44f0959a85b0a801a687e781605df81e3b1c6d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:21:25 -0500
Subject: [PATCH 38/43] Add top-level if condition for each src type
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 78 +++++++++++--------
1 file changed, 47 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f9bbc6535f0e1..d23a89a199131 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1667,37 +1667,53 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
- if (isa<fp4>(srcElemType) && destElemType.isF16())
- return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
- if (isa<fp8>(srcElemType) && destElemType.isF16())
- return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
- if (isa<bf8>(srcElemType) && destElemType.isF16())
- return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
- if (isa<fp4>(srcElemType) && destElemType.isBF16())
- return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
- if (isa<fp8>(srcElemType) && destElemType.isBF16())
- return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
- if (isa<bf8>(srcElemType) && destElemType.isBF16())
- return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
- if (isa<fp4>(srcElemType) && destElemType.isF32())
- return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
- if (isa<fp8>(srcElemType) && destElemType.isF32())
- return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
- if (isa<bf8>(srcElemType) && destElemType.isF32())
- return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
- if (isa<fp6>(srcElemType) && destElemType.isF16())
- return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
- if (isa<bf6>(srcElemType) && destElemType.isF16())
- return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
- if (isa<fp6>(srcElemType) && destElemType.isBF16())
- return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
- if (isa<bf6>(srcElemType) && destElemType.isBF16())
- return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
- if (isa<fp6>(srcElemType) && destElemType.isF32())
- return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
- if (isa<bf6>(srcElemType) && destElemType.isF32())
- return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
- return std::nullopt;
+ if (isa<fp4>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+ return std::nullopt;
+ }
+ llvm_unreachable("invalid combination of element types for packed conversion "
+ "instructions");
}
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
>From f06a67e94e08dd242c6a89dd06023b6ce5d95ad0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:37:30 -0500
Subject: [PATCH 39/43] Add chipset constant at beginning of file
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d23a89a199131..4d1734fe710c0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
+constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
k, isRDNA3);
// Handle gfx1250.
- if (chipset == Chipset{12, 5, 0})
+ if (chipset == kGfx1250)
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);
@@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
- bool isGFX1250 = chipset >= Chipset(12, 5, 0);
+ bool isGFX1250 = chipset >= kGfx1250;
// The WMMA operations represent vectors of bf16s as vectors of i16s
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
@@ -1725,7 +1726,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
Location loc = op.getLoc();
- if (chipset != Chipset{12, 5, 0}) {
+ if (chipset != kGfx1250) {
return rewriter.notifyMatchFailure(
loc,
"Scaled fp packed conversion instructions are not available on target "
>From 1dbcb95412e274d598c337b700b2524130fd3d25 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:42:14 -0500
Subject: [PATCH 40/43] wip
---
mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 811a8e49dc5c6..87e4eb363d343 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \
+// RUN: | FileCheck %s
// CHECK-LABEL: @scaled_ext_packed816_fp4
// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
>From eee0ce97137e993bed9450d281536e74d21d8108 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:52:55 -0500
Subject: [PATCH 41/43] Add invalid srcElemType case
---
.../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 40 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/invalid.mlir | 32 ---------------
2 files changed, 40 insertions(+), 32 deletions(-)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 87e4eb363d343..73711a2b98ac9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -115,3 +115,43 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
}
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op operand #0 must be}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ return %ret0: vector<16xf16>
+}
+
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5c8cc8b67c4b3..61fdf29a78cbd 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -333,38 +333,6 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
- func.return
-}
-
-// -----
-
func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
%0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
>From 9860cdd9ffb108f472a18dff751a3401e13f695e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 14:28:08 -0500
Subject: [PATCH 42/43] Update verifiers
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 45 ++++++++++++-------
.../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 7 +++
3 files changed, 37 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4d1734fe710c0..bf83591bf6047 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1634,7 +1634,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
// firstScaleByte are merged into a single attribute scaleSel. This is how
// those values are merged together.
assert(llvm::is_contained({16, 32}, blockSize));
- assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 5c35823678576..955de3bb861ba 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
LogicalResult ScaledExtPacked816Op::verify() {
int blockSize = getBlockSize();
- assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+ assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
int firstScaleByte = getFirstScaleByte();
+ int firstScaleLane = getFirstScaleLane();
auto sourceType = cast<VectorType>(getSource().getType());
Type elementType = sourceType.getElementType();
auto floatType = cast<FloatType>(elementType);
- int bitWidth = floatType.getWidth();
+ unsigned bitWidth = floatType.getWidth();
- if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
- !llvm::is_contained({0, 1}, firstScaleByte)) {
- return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
- "for f4 and f6.");
- }
- if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
- !llvm::is_contained({0, 2}, firstScaleByte)) {
- return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
- "for f4 and f6.");
- }
- if (bitWidth == 8 && blockSize == 16 &&
- !llvm::is_contained({0, 2}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+ const bool is_fp8 = bitWidth == 8;
+ const bool is_block_16 = blockSize == 16;
+
+ if (!is_fp8) {
+ if (is_block_16) {
+ if (!llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
+ "or 1 for f4 and f6.");
+ }
+ } else {
+ if (!llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
+ "or 2 for f4 and f6.");
+ }
+ }
+ } else {
+ if (is_block_16) {
+ bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
+ ((firstScaleLane == 1) && (firstScaleByte == 2));
+ if (!is_valid) {
+ return emitOpError(
+ "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+ }
+ }
}
return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 73711a2b98ac9..fbe13a29c53ab 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -155,3 +155,10 @@ func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>,
return %ret0: vector<16xf16>
}
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op result #0 must be vector}}
+ %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
+ return %ret0: vector<16xf64>
+}
>From 0b8f561e00d2cea1a6e60e53e13720a617611330 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 14:35:10 -0500
Subject: [PATCH 43/43] Update verifier message
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 ++--
mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 955de3bb861ba..d55f3cec47c1f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -374,8 +374,8 @@ LogicalResult ScaledExtPacked816Op::verify() {
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
((firstScaleLane == 1) && (firstScaleByte == 2));
if (!is_valid) {
- return emitOpError(
- "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+ return emitOpError("blockSize of 16 can only have (firstScaleLane, "
+ "firstScaleByte) be (0, 0) or (1, 2) for f8.");
}
}
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index fbe13a29c53ab..d2391140ce056 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -134,7 +134,7 @@ func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_3
// -----
func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (1, 2) for f8.}}
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
func.return
}
More information about the Mlir-commits
mailing list