[Mlir-commits] [mlir] [mlir][arith][nfc] Adding examples to scaling_extf/truncf descriptions (PR #163980)
llvmlistbot at llvm.org
Fri Oct 17 09:11:30 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/163980
>From cd2368dbd99fe535399eab980c6dfdd8f3641a0c Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 17 Oct 2025 10:25:02 -0500
Subject: [PATCH 1/2] Adding examples to scaling_extf/truncf descriptions
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 60 +++++++++++++------
1 file changed, 41 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..dcac54e074c4d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,25 +1229,25 @@ def Arith_ScalingExtFOp
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
- This operation upcasts input floating-point values using provided scale
- values. It expects both scales and the input operand to be of the same shape,
- making the operation elementwise. Scales are usually calculated per block
- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
- If scales are calculated per block where blockSize != 1, then scales may
- require broadcasting to make this operation elementwise. For example, let's
- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
- `<1 x 1 x ... (dimN/blockSize) x 1>`.
-
- In this example, before calling into `arith.scaling_extf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
- `arith.scaling_extf` would perform the following:
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_extf` would perform the following:
```
resultTy = get_type(result)
@@ -1260,6 +1260,17 @@ def Arith_ScalingExtFOp
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
+
+ Example:
+
+ ```mlir
+ // Upcast from f4E2M1FN to f32.
+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+ // Element-wise upcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
@@ -1406,6 +1417,17 @@ def Arith_ScalingTruncFOp
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultTy)
```
+
+ Example:
+
+ ```mlir
+ // Downcast from f32 to f4E2M1FN.
+ %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
+
+ // Element-wise downcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
>From eadedab2ec3374f08a3a5cacf68d07fa4bbf8fdc Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 17 Oct 2025 11:11:19 -0500
Subject: [PATCH 2/2] improving non-mlir example
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 36 ++++++++++---------
1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index dcac54e074c4d..a38cf41a3e09b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1249,14 +1249,16 @@ def Arith_ScalingExtFOp
that there could be multiple quantization axes. Internally,
`arith.scaling_extf` would perform the following:
- ```
- resultTy = get_type(result)
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
- input.extf = arith.extf(input) : inputTy to resultTy
- result = arith.mulf(scale.extf, input.extf)
+ ```mlir
+ // Cast scale to result type.
+ %0 = arith.truncf %scale : f32 to f8E8M0FNU
+ %1 = arith.extf %0 : f8E8M0FNU to f16
+
+ // Cast input to result type.
+ %2 = arith.extf %input : f4E2M1FN to f16
+
+ // Perform scaling.
+ %3 = arith.mulf %2, %1 : f16
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
@@ -1408,14 +1410,16 @@ def Arith_ScalingTruncFOp
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:
- ```
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- resultTy = get_type(result)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
- result = arith.divf(input, scale.extf)
- result.cast = arith.truncf(result, resultTy)
+ ```mlir
+ // Cast scale to input type.
+ %0 = arith.truncf %scale : f32 to f8E8M0FNU
+ %1 = arith.extf %0 : f8E8M0FNU to f16
+
+ // Perform scaling.
+ %2 = arith.divf %input, %1 : f16
+
+ // Cast to result type.
+ %3 = arith.truncf %2 : f16 to f4E2M1FN
```
Example:
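The block-scaled downcast/upcast flow described in these patches can be sketched numerically in plain Python. This is a hypothetical NumPy model of the per-block scaling the op descriptions explain (scales computed per block of `blockSize` elements, broadcast back to the input shape, then divided or multiplied elementwise); it is not the MLIR implementation, and `scaling_truncf`/`scaling_extf` here are illustrative names only.

```python
import numpy as np

def scaling_truncf(x, scales, block_size=32):
    """Hypothetical model of arith.scaling_truncf: divide each block of
    block_size elements by its scale (a real lowering would then narrow
    the result to the small float type)."""
    # Broadcast one scale per block to the elementwise shape, mirroring
    # the vector.broadcast step in the patch's MLIR example.
    expanded = np.repeat(scales, block_size)
    return x / expanded

def scaling_extf(q, scales, block_size=32):
    """Hypothetical model of arith.scaling_extf: the inverse direction,
    multiplying each block by its scale after widening."""
    expanded = np.repeat(scales, block_size)
    return q * expanded

x = np.array([2.0, 4.0, -4.0, 8.0] * 8)  # 32 elements = one block
s = np.array([4.0])                      # one power-of-two scale per block
q = scaling_extf(scaling_truncf(x, s), s)
print(np.allclose(q, x))  # → True
```

With exactly representable power-of-two scales the scale/unscale pair round-trips; in the real ops the intermediate `truncf` to a narrow type like `f4E2M1FN` is lossy, so only the scaling step itself is exact.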