[Mlir-commits] [mlir] [mlir][arith][nfc] Adding examples to scaling_extf/truncf descriptions (PR #163980)
llvmlistbot at llvm.org
Fri Oct 17 08:41:27 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-arith
Author: Muzammil (Muzammiluddin-Syed-ECE)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/163980.diff
1 File Affected:
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+41-19)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..dcac54e074c4d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,25 +1229,25 @@ def Arith_ScalingExtFOp
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
- This operation upcasts input floating-point values using provided scale
- values. It expects both scales and the input operand to be of the same shape,
- making the operation elementwise. Scales are usually calculated per block
- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
- If scales are calculated per block where blockSize != 1, then scales may
- require broadcasting to make this operation elementwise. For example, let's
- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
- `<1 x 1 x ... (dimN/blockSize) x 1>`.
-
- In this example, before calling into `arith.scaling_extf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
- `arith.scaling_extf` would perform the following:
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_extf` would perform the following:
```
resultTy = get_type(result)
@@ -1260,6 +1260,17 @@ def Arith_ScalingExtFOp
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
+
+ Example:
+
+ ```mlir
+ // Upcast from f4E2M1FN to f32.
+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+ // Element-wise upcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
@@ -1406,6 +1417,17 @@ def Arith_ScalingTruncFOp
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultTy)
```
+
+ Example:
+
+ ```mlir
+ // Downcast from f32 to f4E2M1FN.
+ %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
+
+ // Element-wise downcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
``````````
</details>
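For readers unfamiliar with the two ops, the elementwise semantics described in the diff can be sketched in plain Python. This is only an illustration under stated assumptions, not the MLIR lowering: values and scales are ordinary Python floats (already decoded from f4E2M1FN / f8E8M0FNU), the final rounding to the narrow result type is not modeled, the upcast is assumed to multiply by the scale (the mirror of the `arith.divf` shown in the `scaling_truncf` hunk), and the helper names `broadcast_scales`, `scaling_extf`, and `scaling_truncf` are hypothetical.

```python
import math


def broadcast_scales(block_scales, block_size):
    """Repeat each per-block scale block_size times so scales match the input shape."""
    return [s for s in block_scales for _ in range(block_size)]


def scaling_extf(values, scales):
    """Elementwise upcast sketch: result = input * scale (assumed; NaN propagates)."""
    if len(values) != len(scales):
        raise ValueError("input and scales must have the same shape")
    return [v * s for v, s in zip(values, scales)]


def scaling_truncf(values, scales):
    """Elementwise downcast sketch: result = input / scale (rounding not modeled)."""
    if len(values) != len(scales):
        raise ValueError("input and scales must have the same shape")
    return [v / s for v, s in zip(values, scales)]


# Two blocks of size 2, one power-of-two scale per block.
quantized = [0.5, 1.0, 1.5, 2.0]
scales = broadcast_scales([2.0, 0.25], block_size=2)

upcast = scaling_extf(quantized, scales)
roundtrip = scaling_truncf(upcast, scales)
```

Because the scales here are powers of two, as E8M0 scales are, the multiply and divide are exact and `roundtrip` equals `quantized`. A real `arith.scaling_truncf` would additionally round to the narrow result type.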
https://github.com/llvm/llvm-project/pull/163980