[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu May 29 13:01:04 PDT 2025


================
@@ -409,6 +421,112 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// Expands `arith.scaling_extf` into plain arith ops: the input and the scale
+/// are each extended to the op's result type with `arith.extf`, and the result
+/// is their product, i.e. `extf(in) * extf(scale)`.
+///
+/// Only scales whose element type is f8E8M0FNU are handled; any other scale
+/// element type is reported as a match failure (no IR is changed).
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling extf is not using scale operand of type f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    // extf on the scale materializes 2^scale directly in the op's result type
+    // (which is not necessarily f32) and also propagates NaN scales.
+    // NOTE(review): if the result element type has a narrower exponent range
+    // than f8E8M0FNU (e.g. f16), extreme scales may not be representable after
+    // this extf -- confirm that is the intended semantics.
+    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
+    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
+    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto inputOperand = op.getIn();
+    auto scaleOperand = op.getScale();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+    }
+    auto scaleTy = scaleOperand.getType();
+
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(op.getOut());
+
+    Type inputTy = inputOperand.getType();
+    Type inputETy = getElementTypeOrSelf(inputOperand);
+
+    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
+    Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
+
+    if (inputETy.getIntOrFloatBitWidth() < 32) {
+      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
+    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
+      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
+    }
+    inputTy = inputOperand.getType();
+    inputETy = getElementTypeOrSelf(inputOperand);
+
+    // normalize scale by exponent of the max normal value in result type as per
+    // the OCP MXFP spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
+    const llvm::fltSemantics &resultFltSemantics =
+        llvm::cast<FloatType>(resultETy).getFloatSemantics();
+    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
+    Value cMaxNormalExponent =
+        createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+    Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
+    Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
+    Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
+    Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
----------------
krzysz00 wrote:

... But also, the code you linked is for quantization.

I think it's reasonable to assume that someone implementing quantization will already have applied the scale biasing, so we don't need to do it here.

That is, unless we have evidence that the hardware implementations perform the subtraction described here. (We'll probably want to go look up the AMD behavior.)

https://github.com/llvm/llvm-project/pull/141965


More information about the Mlir-commits mailing list