[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 17 08:10:56 PST 2025


================
@@ -1613,6 +1613,154 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+                    int32_t firstScaleLane, int32_t firstScaleByte) {
+  // Lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f* folds
+  // the blockSize, sourceType, firstScaleLane and firstScaleByte attributes
+  // into one scaleSel attribute. This helper computes that packed value.
+  //
+  // Layout produced below:
+  //   8-bit sources:     bit 3 = lane, bits 2..1 = byte, bit 0 = block is 16
+  //   4/6-bit sources:   bit 2 = lane, bit 1 = (byte == 2), bit 0 = block is 16
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+  // Bit 0 has the same meaning for every source width.
+  const int blockBit = blockSize == 16 ? 1 : 0;
+
+  if (bitWidth == 8) {
+    // firstScaleByte is guaranteed to be defined by two bits.
+    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+    const int byteBits = firstScaleByte << 1;
+    assert(llvm::is_contained({0, 1}, firstScaleLane));
+    const int laneBit = firstScaleLane << 3;
+    const int scaleSel = laneBit | byteBits | blockBit;
+    // These are invalid cases.
+    assert(!llvm::is_contained(
+        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, scaleSel));
+    return scaleSel;
+  }
+
+  // 4-bit and 6-bit sources: only bytes 0 and 2 are accepted, so one bit is
+  // enough to record the choice.
+  assert(llvm::is_contained({0, 2}, firstScaleByte));
+  const int byteBit = firstScaleByte == 2 ? 2 : 0;
+  assert(llvm::is_contained({0, 1}, firstScaleLane));
+  const int laneBit = firstScaleLane << 2;
+  return laneBit | byteBit | blockBit;
+}
+
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+  // Maps a (source element type, destination element type) pair to the name
+  // of the ROCDL scaled-conversion operation handling it, or std::nullopt
+  // when the pair has no matching operation.
+
+  // Picks the f16 / bf16 / f32 variant for an already-matched source format.
+  // The destination type is only queried once a source format has matched.
+  auto selectByDest = [&](StringRef toF16, StringRef toBF16,
+                          StringRef toF32) -> std::optional<StringRef> {
+    if (destElemType.isF16())
+      return toF16;
+    if (destElemType.isBF16())
+      return toBF16;
+    if (destElemType.isF32())
+      return toF32;
+    return std::nullopt;
+  };
+
+  // Sources converted eight elements at a time (Pk8 operations).
+  if (isa<Float4E2M1FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName());
+  if (isa<Float8E4M3FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName());
+  if (isa<Float8E5M2Type>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(),
+                        ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName());
+
+  // Sources converted sixteen elements at a time (Pk16 operations).
+  if (isa<Float6E2M3FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName());
+  if (isa<Float6E3M2FNType>(srcElemType))
+    return selectByDest(ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(),
+                        ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName());
+
+  return std::nullopt;
+}
+
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+    ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  using fp4 = Float4E2M1FNType;
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+  using fp6 = Float6E2M3FNType;
+  using bf6 = Float6E3M2FNType;
+  Location loc = op.getLoc();
+  if (chipset != Chipset{12, 5, 0}) {
+    return rewriter.notifyMatchFailure(
+        loc,
+        "Scaled fp packed conversion instructions are not available on target "
+        "architecture and their emulation is not implemented");
+  }
+  int32_t firstScaleLane = op.getFirstScaleLane();
+  int32_t firstScaleByte = op.getFirstScaleByte();
+  int32_t blockSize = op.getBlockSize();
+  auto sourceType = cast<VectorType>(op.getSource().getType());
+  auto srcElemType = cast<FloatType>(sourceType.getElementType());
+  unsigned bitWidth = srcElemType.getWidth();
+  int32_t scaleSel =
+      getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+  auto targetType = cast<VectorType>(op.getResult().getType());
+  auto destElemType = cast<FloatType>(targetType.getElementType());
+  IntegerType i32 = rewriter.getI32Type();
+  Value castedScale =
+      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
+  Value source = adaptor.getSource();
+  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+  Type packedType = nullptr;
+  if (isa<fp4>(srcElemType)) {
+    packedType = i32;
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<fp8, bf8>(srcElemType)) {
+    packedType = VectorType::get(2, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<fp6, bf6>(srcElemType)) {
+    packedType = VectorType::get(3, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else {
+    llvm_unreachable("invalid element type for packed scaled ext");
----------------
kuhar wrote:

Are you sure the remaining cases are guaranteed not to get here through the operation verifiers? If the verifier doesn't prevent these, we can't assert and need to return an error instead.

https://github.com/llvm/llvm-project/pull/168123


More information about the Mlir-commits mailing list