[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Daniel Hernandez-Juarez llvmlistbot at llvm.org
Mon Jun 2 05:35:34 PDT 2025


================
@@ -409,6 +421,125 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+// Lowers arith.scaling_extf: extend both the (f8E8M0) scale and the input
+// operand to the result type, then multiply. Since f8E8M0 is an
+// exponent-only format, extending the scale yields a power of two, and the
+// multiply applies it; a NaN scale propagates through the extf.
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // Allow implicit exponent extraction from 16/32-bit floats: truncate the
+    // scale down to f8E8M0, keeping only the biased exponent (presumably this
+    // TruncFOp is itself lowered by F8E8M0TruncFOpConverter above — confirm).
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+    }
+    // Any other scale element type (e.g. a different f8 flavor) is rejected;
+    // the pattern only knows how to interpret f8E8M0FNU scales.
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling extf is not using scale operand of type f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    // extf on the scale essentially creates a result-type value equal to
+    // 2^scale and will also propagate NaNs.
+    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
+    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
+    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from 16/32 bits floats
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+      scaleTy = scaleOperand.getType();
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+    }
+
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(op.getOut());
+
+    Type inputTy = inputOperand.getType();
+    Type inputETy = getElementTypeOrSelf(inputOperand);
+
+    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
+    Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
+
+    if (inputETy.getIntOrFloatBitWidth() < 32) {
+      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
+    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
+      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
+    }
+    inputTy = inputOperand.getType();
+    inputETy = getElementTypeOrSelf(inputOperand);
+
+    // normalize scale by exponent of the max normal value in result type as per
+    // the OCP MXFP spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
+    const llvm::fltSemantics &resultFltSemantics =
+        llvm::cast<FloatType>(resultETy).getFloatSemantics();
+    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
+    Value cMaxNormalExponent =
+        createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+    Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
+    Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
+    Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
+    Value scaleI32 = b.create<arith::ExtUIOp>(i32Ty, scaleI8);
+    Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
+    Value normalizedUnbiasedScale =
+        b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
+    // clamp scale exponent as per spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
+    // upper clamp limit of 127 will be mapped to biased value of 255 and will
+    // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
+    // using arith.extf
+    Value clampUpperCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
+    Value clampLowerCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
+    Value clampedScale = b.create<arith::SelectOp>(
+        clampUpperCond, c127,
+        b.create<arith::SelectOp>(clampLowerCond, cNeg127,
+                                  normalizedUnbiasedScale));
+    Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
+    Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
+    Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
+    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
+    // flush denorms by checking if exponent part of input operand is zero
+    // or not.
+    Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
+    Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
+    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
+                                            inputExponentU8);
+    Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
+    Value flushedInput =
+        b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
+    Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
+    // propagate rounding mode and fast math attributes
+    Value resultCast = b.create<arith::TruncFOp>(
+        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
----------------
dhernandez0 wrote:

> No, why do you think so ? Output dtype will be whatever user has specified. 

I mean the result of the function before truncation. result.dtype = f32, right?

> In practice, Float64/80/128 dtypes are something that is not expected. I think it is safe to assume F32 is the largest dtype that can appear on the input.

I think the arith dialect is not supposed to be hardware-specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow.

https://github.com/llvm/llvm-project/pull/141965


More information about the Mlir-commits mailing list