[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 17 08:10:56 PST 2025
================
@@ -1613,6 +1613,154 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+/// Computes the `scaleSel` value used when lowering
+/// amdgpu.scaled_ext_packed816 to the rocdl.cvt.scale.pk*.f*.f* operations.
+///
+/// The attributes blockSize, sourceType (represented here by the source
+/// element `bitWidth`), firstScaleLane and firstScaleByte are merged into a
+/// single integer:
+///   - 4- and 6-bit sources: bit 0 = block size is 16,
+///                           bit 1 = firstScaleByte is 2,
+///                           bit 2 = firstScaleLane.
+///   - 8-bit sources:        bit 0 = block size is 16,
+///                           bits 2..1 = firstScaleByte,
+///                           bit 3 = firstScaleLane.
+///
+/// Out-of-range arguments and the invalid 8-bit combinations are only caught
+/// by assertions.
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+                    int32_t firstScaleLane, int32_t firstScaleByte) {
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+  // Bit 0 has the same meaning in both layouts.
+  const int blockBit = blockSize == 16;
+
+  if (bitWidth == 8) {
+    // firstScaleByte is guaranteed to be defined by two bits.
+    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+    assert(llvm::is_contained({0, 1}, firstScaleLane));
+    const int scaleSel =
+        (firstScaleLane << 3) | (firstScaleByte << 1) | blockBit;
+    // These are invalid cases.
+    assert(!llvm::is_contained(
+        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, scaleSel));
+    return scaleSel;
+  }
+
+  // 4- and 6-bit sources: only bytes 0 and 2 are accepted, so a single bit
+  // is enough to encode the byte selection.
+  assert(llvm::is_contained({0, 2}, firstScaleByte));
+  assert(llvm::is_contained({0, 1}, firstScaleLane));
+  const int byteBit = firstScaleByte == 2 ? 2 : 0;
+  return (firstScaleLane << 2) | byteBit | blockBit;
+}
+
+/// Returns the name of the ROCDL scaled-conversion operation that converts
+/// elements of type `srcElemType` into `destElemType`, or std::nullopt when
+/// the pair is not one of the supported combinations.
+///
+/// Supported sources are the 4-bit (E2M1FN), 8-bit (E4M3FN, E5M2) and 6-bit
+/// (E2M3FN, E3M2FN) float types; supported destinations are f16, bf16 and
+/// f32. The 4- and 8-bit sources map to the Pk8 operations, the 6-bit sources
+/// to the Pk16 operations.
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+  // Picks among the three candidates for one source type based on the
+  // destination element type.
+  auto selectByDest = [&](StringRef f16Name, StringRef bf16Name,
+                          StringRef f32Name) -> std::optional<StringRef> {
+    if (destElemType.isF16())
+      return f16Name;
+    if (destElemType.isBF16())
+      return bf16Name;
+    if (destElemType.isF32())
+      return f32Name;
+    return std::nullopt;
+  };
+
+  if (isa<Float4E2M1FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName());
+  if (isa<Float8E4M3FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName());
+  if (isa<Float8E5M2Type>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName());
+  if (isa<Float6E2M3FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName());
+  if (isa<Float6E3M2FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName());
+  return std::nullopt;
+}
+
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+ ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ Location loc = op.getLoc();
+ if (chipset != Chipset{12, 5, 0}) {
----------------
kuhar wrote:
Maybe define a named constant for this chipset (`Chipset{12, 5, 0}`) at the very top of the file?
https://github.com/llvm/llvm-project/pull/168123
More information about the Mlir-commits
mailing list