[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)
Umang Yadav
llvmlistbot at llvm.org
Fri May 30 11:47:55 PDT 2025
================
@@ -409,6 +421,112 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+// Expands arith.scaling_extf into plain arith ops: extend both the input and
+// the f8E8M0FNU scale to the result float type, then multiply
+// (result = input * 2^unbiased_scale_exponent).
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto inputOperand = op.getIn();
+ auto scaleOperand = op.getScale();
+ // This pattern only handles the f8E8M0FNU scale encoding; bail out
+ // (match failure, not an error) for any other scale element type.
+ if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+ return rewriter.notifyMatchFailure(
+ op, "scaling extf is not using scale operand of type f8E8M0FNU");
+ }
+ Type resultTy = op.getType();
+ // extf on the E8M0 scale materializes 2^exponent in the result float type
+ // (not necessarily f32); the all-ones E8M0 encoding extends to NaN, so a
+ // NaN scale propagates through the multiply below.
+ Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
+ Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
+ Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+// Expands arith.scaling_truncf: normalize the input to f32, adjust and clamp
+// the f8E8M0FNU scale exponent per the OCP MXFP spec, then (in the portion of
+// this hunk cut off by the review anchor below) scale, flush denorms, and
+// truncate to the result type.
+struct ScalingTruncFOpConverter
+ : public OpRewritePattern<arith::ScalingTruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto inputOperand = op.getIn();
+ auto scaleOperand = op.getScale();
+ // Only the f8E8M0FNU scale encoding is handled; other scale element
+ // types produce a match failure so another pattern (or none) can apply.
+ if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+ return rewriter.notifyMatchFailure(
+ op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+ }
+ auto scaleTy = scaleOperand.getType();
+
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(op.getOut());
+
+ Type inputTy = inputOperand.getType();
+ Type inputETy = getElementTypeOrSelf(inputOperand);
+
+ // Helper types cloned to the result's shape (scalar or shaped alike).
+ Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
+ Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
+
+ // Canonicalize the input element width to f32 so the rest of the
+ // expansion works on a single intermediate precision.
+ if (inputETy.getIntOrFloatBitWidth() < 32) {
+ inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
+ } else if (inputETy.getIntOrFloatBitWidth() > 32) {
+ inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
+ }
+ inputTy = inputOperand.getType();
+ inputETy = getElementTypeOrSelf(inputOperand);
+
+ // normalize scale by exponent of the max normal value in result type as per
+ // the OCP MXFP spec
+ // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
+ const llvm::fltSemantics &resultFltSemantics =
+ llvm::cast<FloatType>(resultETy).getFloatSemantics();
+ int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
+ Value cMaxNormalExponent =
+ createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+ Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
+ Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
+ Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
+ // NOTE(review): the E8M0 biased exponent is an unsigned 0..255 value, but
+ // ExtSIOp sign-extends, so bit patterns >= 0x80 (any scale >= 2^1) become
+ // negative before the bias subtraction below. Should this be ExtUIOp?
+ // Please confirm against the E8M0 decode tests.
+ Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
+ Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
+ Value normalizedUnbiasedScale =
+ b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
+ // clamp scale exponent as per spec
+ // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
+ // upper clamp limit of 127 will be mapped to biased value of 255 and will
+ // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
+ // using arith.extf
+ Value clampUpperCond = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
+ Value clampLowerCond = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
+ Value clampedScale = b.create<arith::SelectOp>(
+ clampUpperCond, c127,
+ b.create<arith::SelectOp>(clampLowerCond, cNeg127,
+ normalizedUnbiasedScale));
+ // Re-bias and pack the clamped exponent back into an E8M0 scale, then
+ // materialize it as the f32 multiplier 2^clampedScale.
+ Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
+ Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
+ Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
+ Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
+ // flush denorms by checking if exponent part of input operand is zero
+ // or not.
+ // NOTE(review): truncf targets scaleTy (the scale operand's type) but the
+ // following bitcast uses i8Ty cloned from the result type — assumes scale
+ // and result shapes match; f8Ty would be the shape-consistent choice.
+ Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
+ Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
+ Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+ Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
+ inputExponentU8);
+ Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
+ Value flushedInput =
----------------
umangyadav wrote:
Rewrote using f32. It does simplify things a bit. Thanks
https://github.com/llvm/llvm-project/pull/141965
More information about the Mlir-commits
mailing list