[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled floating point conversion ops fp8 (PR #141554)

Umang Yadav llvmlistbot at llvm.org
Tue Jun 10 06:00:08 PDT 2025


================
@@ -1230,6 +1257,157 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
+    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const { // Replaces op with the ROCDL CvtScaleF32Pk* op chosen by its element types.
+  Location loc = op.getLoc();
+  if (chipset != kGfx950) // Any chipset other than kGfx950 is reported as a match failure.
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Value source = adaptor.getSource();
+  Value scale = adaptor.getScale();
+
+  VectorType sourceVecType = dyn_cast<VectorType>(op.getSource().getType()); // Null when the source is not a vector.
+  Type sourceElemType = getElementTypeOrSelf(op.getSource());
+  VectorType destVecType = dyn_cast<VectorType>(op.getResult().getType());
+  Type destElemType = getElementTypeOrSelf(op.getResult());
+
+  VectorType packedVecType; // Vector type the source is widened to when it has fewer elements.
+  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) { // Float8 sources: converted 4 x i8 vector.
+    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
+    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
+  } else if (isa<Float4E2M1FNType>(sourceElemType)) { // Float4 sources: converted 8 x i4 vector.
+    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
+    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
+  } else {
+    llvm_unreachable("invalid element type for scaled ext");
+  }
+
+  // Widen a non-vector source, or a vector with fewer elements, to packedVecType.
+  if (!sourceVecType ||
+      sourceVecType.getNumElements() < packedVecType.getNumElements()) {
+    Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType); // Start from zeros; positions not written below stay zero.
+    if (!sourceVecType) {
+      longVec = rewriter.create<LLVM::InsertElementOp>(
+          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+    } else {
+      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { // Copy each source element to the same index.
+        Value idx = createI32Constant(rewriter, loc, i);
+        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        longVec =
+            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+      }
+    }
+    source = longVec;
+  }
+  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); // Reinterpret the source as the converted i32 type.
+
+  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32()) // Dispatch on the (source, result) element type pair.
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
+    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
+        op, destVecType, i32Source, scale, op.getIndex());
+  else // None of the element type pairs above matched.
+    return failure();
+
+  return success();
+}
+
+LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
+    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset != kGfx950)
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type v2i16 = getTypeConverter()->convertType(
+      VectorType::get(2, rewriter.getI16Type()));
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+  Type sourceElemType = getElementTypeOrSelf(op.getSource());
+
+  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
+
+  Value source = adaptor.getSource();
+  Value scale = adaptor.getScale();
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+  else
+    existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
+
+  Value sourceA, sourceB;
+  if (sourceElemType.isF32()) {
+    Value c0 = createI32Constant(rewriter, loc, 0);
+    Value c1 = createI32Constant(rewriter, loc, 1);
+    sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+    sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+  }
+
+  Value result;
+  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
+        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
+        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
+        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
+        loc, intResultType, existing, source, scale, op.getIndex());
+  else
+    return failure();
----------------
umangyadav wrote:

```suggestion
if (isa<Float8E5M2Type>(resultElemType)) {
  if (sourceElemType.isF32()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
  } else if (sourceElemType.isF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else if (sourceElemType.isBF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else {
    return failure();
  }
} else if (isa<Float8E4M3FNType>(resultElemType)) {
  if (sourceElemType.isF32()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
  } else if (sourceElemType.isF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else if (sourceElemType.isBF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else {
    return failure();
  }
} else if (isa<Float4E2M1FNType>(resultElemType)) {
  if (sourceElemType.isF32()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
  } else if (sourceElemType.isF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else if (sourceElemType.isBF16()) {
    result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
        loc, intResultType, existing, source, scale, op.getIndex());
  } else {
    return failure();
  }
} else {
  return failure();
}
```

https://github.com/llvm/llvm-project/pull/141554


More information about the Mlir-commits mailing list