[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Umang Yadav llvmlistbot at llvm.org
Fri Jun 6 16:41:30 PDT 2025


================
@@ -1280,6 +1333,64 @@ def Arith_TruncFOp :
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp
+    : Arith_Op<"scaling_truncf",
+               [Pure, SameInputOutputTensorDims,
+                DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+                DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "Downcasts input floating point values using provided scales values following OCP MXFP Spec";
+  let description = [{
+    This operation quantizes the input using the provided scale values. It
+    expects the scales and the input operand to have the same shape, which
+    makes the operation elementwise. Scales are usually calculated per block
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+    Users are required to normalize and clamp the scales as necessary before
+    passing them to this operation. The OCP MXFP spec also requires flushing
+    denorms on the input operand; this should be handled during lowering by
+    passing an appropriate fastMath flag to this operation.
+
+    If scales are calculated per block where blockSize != 1, scales may require 
+    broadcasting to make this operation elementwise. For example, let's say the 
+    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
+    assuming quantization happens on the last axis, the input can be reshaped to 
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
+    per block on the last axis. Therefore, scales will be of shape 
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also have some
+    other shape, as long as they are broadcast compatible with the input, e.g.,
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_truncf`, scales must be 
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
+    that there could be multiple quantization axes. Internally, 
+    `arith.scaling_truncf` would perform the following:
+
+    ```
+    scaleTy = get_type(scale)
+    inputTy = get_type(input)
+    resultTy = get_type(result)
+    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
----------------
umangyadav wrote:

Removed.

https://github.com/llvm/llvm-project/pull/141965
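
For readers following from the list, here is a minimal sketch of how the
broadcast-then-truncate flow described in the quoted documentation might look
in IR once the op lands. The exact assembly format, the f8E8M0FNU scale type,
the f4E2M1FN result type, and the tensor shapes below are illustrative
assumptions, not excerpts from the PR:

```mlir
// Input of shape <8x32>, quantized in blocks of 32 along the last axis.
// The per-block scales of shape <8x1> are assumed to have already been
// normalized, clamped, and broadcast to the input shape at this point.
%quantized = arith.scaling_truncf %input, %scales
    : tensor<8x32xf32>, tensor<8x32xf8E8M0FNU> to tensor<8x32xf4E2M1FN>
```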


More information about the Mlir-commits mailing list