[Mlir-commits] [mlir] [mlir][arith][nfc] Adding examples to scaling_extf/truncf descriptions (PR #163980)

llvmlistbot at llvm.org
Fri Oct 17 08:41:27 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-arith

Author: Muzammil (Muzammiluddin-Syed-ECE)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/163980.diff


1 File Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+41-19) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..dcac54e074c4d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,25 +1229,25 @@ def Arith_ScalingExtFOp
   let summary = "Upcasts input floats using provided scales values following "
                 "OCP MXFP Spec";
   let description = [{
-  This operation upcasts input floating-point values using provided scale
-  values. It expects both scales and the input operand to be of the same shape,
-  making the operation elementwise. Scales are usually calculated per block
-  following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
-  If scales are calculated per block where blockSize != 1, then scales may
-  require broadcasting to make this operation elementwise. For example, let's
-  say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
-  assuming quantization happens on the last axis, the input can be reshaped to
-  `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
-  per block on the last axis. Therefore, scales will be of shape
-  `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
-  shape as long as it is broadcast compatible with the input, e.g.,
-  `<1 x 1 x ... (dimN/blockSize) x 1>`.
-
-  In this example, before calling into `arith.scaling_extf`, scales must be
-  broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
-  that there could be multiple quantization axes. Internally,
-  `arith.scaling_extf` would perform the following:
+    This operation upcasts input floating-point values using provided scale
+    values. It expects both scales and the input operand to be of the same shape,
+    making the operation elementwise. Scales are usually calculated per block
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+    If scales are calculated per block where blockSize != 1, then scales may
+    require broadcasting to make this operation elementwise. For example, let's
+    say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+    assuming quantization happens on the last axis, the input can be reshaped to
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+    per block on the last axis. Therefore, scales will be of shape
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+    shape as long as it is broadcast compatible with the input, e.g.,
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_extf`, scales must be
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+    that there could be multiple quantization axes. Internally,
+    `arith.scaling_extf` would perform the following:
 
     ```
     resultTy = get_type(result)
@@ -1260,6 +1260,17 @@ def Arith_ScalingExtFOp
     ```
     It propagates NaN values. Therefore, if either scale or the input element
     contains NaN, then the output element value will also be a NaN.
+
+    Example:
+
+    ```mlir
+    // Upcast from f4E2M1FN to f32.
+    %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+    // Element-wise upcast with broadcast (blockSize = 32).
+    %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+    %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+    ```
   }];
   let hasVerifier = 1;
   let assemblyFormat =
@@ -1406,6 +1417,17 @@ def Arith_ScalingTruncFOp
     result = arith.divf(input, scale.extf)
     result.cast = arith.truncf(result, resultTy)
     ```
+
+    Example:
+
+    ```mlir
+    // Downcast from f32 to f4E2M1FN.
+    %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
+
+    // Element-wise downcast with broadcast (blockSize = 32).
+    %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+    %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
+    ```
   }];
   let hasVerifier = 1;
   let assemblyFormat =

``````````
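The `arith.scaling_extf` semantics described in the diff are: broadcast each per-block scale across `blockSize` elements, upcast the input, multiply elementwise, and let NaN propagate. They can be illustrated with a small Python reference model. This is a hypothetical sketch, not MLIR code or an MLIR API. The helper names (`decode_e8m0`, `decode_f4e2m1fn`, `scaling_extf_1d`) are made up for illustration. Quantization is assumed to run along a single (last) axis, and inputs are assumed to be raw f4E2M1FN 4-bit codes. The decoding follows the OCP MX formats. E8M0 is an unsigned, exponent-only byte with bias 127, and 0xFF encodes NaN. f4E2M1FN has 1 sign, 2 exponent and 1 mantissa bit, and has no Inf or NaN.

```python
import math

# Hypothetical helper: decode an f8E8M0FNU scale byte. E8M0 is unsigned and
# exponent-only with bias 127, so the value is 2**(bits - 127); 0xFF is NaN.
def decode_e8m0(bits: int) -> float:
    if not 0 <= bits <= 0xFF:
        raise ValueError("E8M0 scale must fit in 8 bits")
    return math.nan if bits == 0xFF else math.ldexp(1.0, bits - 127)

# Non-negative f4E2M1FN values indexed by the low 3 bits (exponent:mantissa);
# bit 3 is the sign. The format has no Inf or NaN encodings.
F4E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_f4e2m1fn(code: int) -> float:
    mag = F4E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag

def scaling_extf_1d(input_codes, scale_bits, block_size):
    """Hypothetical model of arith.scaling_extf quantized along one axis.

    Scale i applies to elements [i * block_size, (i + 1) * block_size).
    """
    if len(input_codes) != len(scale_bits) * block_size:
        raise ValueError("len(input_codes) must equal len(scale_bits) * block_size")
    # Broadcast <... x (dimN/blockSize) x 1> to <... x (dimN/blockSize) x blockSize>,
    # flattened so that each scale lines up elementwise with its block.
    scales = [decode_e8m0(b) for b in scale_bits for _ in range(block_size)]
    # result = extf(input) * scale; a NaN scale makes the product NaN.
    return [decode_f4e2m1fn(c) * s for c, s in zip(input_codes, scales)]
```

For example, the codes `[0x2, 0x9, 0x7, 0x3]` decode to `[1.0, -0.5, 6.0, 1.5]`. With scale bytes `[127, 129]` (that is, 1 and 4) and `block_size = 2`, they upcast to `[1.0, -0.5, 24.0, 6.0]`.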

</details>
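The `arith.scaling_truncf` description in the diff divides the input by the scale (`result = arith.divf(input, scale.extf)`) and then truncates the quotient to the result type. A minimal Python sketch of that behaviour for an f4E2M1FN result, blocked along one axis, is shown below. All helper names are hypothetical and this is not MLIR code. Two details are assumptions rather than statements about the MLIR lowering. First, rounding is round-to-nearest-even. Second, NaN and values outside the finite f4E2M1FN range (|x| > 6) are rejected instead of saturated, because f4E2M1FN has no Inf or NaN encoding and overflow handling is not modeled here. Scales are decoded as OCP MX E8M0 bytes (bias 127, 0xFF is NaN).

```python
import math

# Hypothetical helper: decode an f8E8M0FNU scale byte (bias 127, 0xFF = NaN).
def decode_e8m0(bits: int) -> float:
    if not 0 <= bits <= 0xFF:
        raise ValueError("E8M0 scale must fit in 8 bits")
    return math.nan if bits == 0xFF else math.ldexp(1.0, bits - 127)

# Non-negative f4E2M1FN values indexed by the low 3 bits (exponent:mantissa).
F4E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def round_to_f4e2m1fn(x: float) -> float:
    # Assumption: round-to-nearest-even; NaN and |x| > 6.0 are rejected
    # rather than saturated, since that behaviour is not modeled here.
    if math.isnan(x) or abs(x) > 6.0:
        raise ValueError(f"{x} is outside the modeled f4E2M1FN range")
    mag = abs(x)
    # Nearest magnitude wins; on a tie, prefer the code whose mantissa bit
    # (bit 0) is 0, i.e. the "even" neighbour.
    code = min(range(8), key=lambda c: (abs(F4E2M1_MAGNITUDES[c] - mag), c & 1))
    return math.copysign(F4E2M1_MAGNITUDES[code], x)

def scaling_truncf_1d(inputs, scale_bits, block_size):
    """Hypothetical model of arith.scaling_truncf quantized along one axis."""
    if len(inputs) != len(scale_bits) * block_size:
        raise ValueError("len(inputs) must equal len(scale_bits) * block_size")
    # Broadcast each per-block scale across block_size elements.
    scales = [decode_e8m0(b) for b in scale_bits for _ in range(block_size)]
    # result = truncf(input / scale)
    return [round_to_f4e2m1fn(x / s) for x, s in zip(inputs, scales)]
```

With scale bytes `[128, 130]` (that is, 2 and 8) and `block_size = 2`, the inputs `[3.0, -1.0, 24.0, 5.0]` become `[1.5, -0.5, 3.0, 0.5]`. The last element shows rounding: 5/8 = 0.625 is nearest to 0.5.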


https://github.com/llvm/llvm-project/pull/163980

