[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Prashant Kumar llvmlistbot at llvm.org
Thu May 29 11:54:36 PDT 2025


================
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to larger floating-point using provided scales";
+  let description = [{
+    Implements the micro-scaling floating-point ExtF op. It expects the scale and input operands to have the same shape.
+    The scale operand is expected to be of type f8E8M0, but that may be relaxed in the future.
+    The scale is usually calculated per block.
+    Suppose the input originally has shape <dim1 x dim2 x dim3 x ... x dimN>. Given blockSize, it can be reshaped to <dim1 x dim2 x ... x (dimN/blockSize) x blockSize>.
+    Scales are calculated along the block axis, so the scale has shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x 1>.
+    Before calling into `arith.scaling_extf`, the scales must be broadcast appropriately to match the shape of the input, making `arith.scaling_extf` an elementwise op.
+    In the above example, the scales should be broadcast to shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>.
+    ```
+    resultTy = get_type(result) 
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.bcast = broadcast_to_same_shape_as(result)
+    scale.extf = arith.extf(scale.bcast) : f8E8M0 to resultTy
+    input.extf = arith.extf(input) : inputTy to resultTy
+    result = arith.mulf(scale.extf, input.extf)
+    ```
+    It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
----------------
pashu123 wrote:

```suggestion
    It propagates NaN values. Therefore, if either the scale or the input element is NaN, then the output element value will also be NaN.
```
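
For readers following the description in the hunk above, here is a rough Python sketch of the per-block scaling it describes. It is not part of the PR and not MLIR code. The helper names (`decode_e8m0`, `broadcast_scales`, `scaling_extf`) are hypothetical, plain Python floats stand in for the narrow input type and for `resultTy`, the input is one-dimensional, and the f8E8M0 decoding assumes the usual exponent-only encoding with bias 127 and `0xFF` as NaN.

```python
import math


def decode_e8m0(bits: int) -> float:
    """Decode an f8E8M0 byte (assumed layout: 8 exponent bits, bias 127, 0xFF is NaN)."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("f8E8M0 is an 8-bit encoding")
    if bits == 0xFF:
        return math.nan
    return 2.0 ** (bits - 127)


def broadcast_scales(block_scales, block_size):
    """Repeat each per-block scale block_size times so it matches the input shape."""
    return [s for s in block_scales for _ in range(block_size)]


def scaling_extf(values, scales):
    """Elementwise reference: result[i] = values[i] * scales[i]."""
    if len(values) != len(scales):
        raise ValueError("input and scale must have the same shape")
    # Multiplying by NaN yields NaN, so NaN in either operand propagates.
    return [v * s for v, s in zip(values, scales)]


# Eight input elements split into two blocks of four (blockSize = 4).
values = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0]
block_scales = [decode_e8m0(128), decode_e8m0(126)]  # 2.0 and 0.5
scales = broadcast_scales(block_scales, 4)
result = scaling_extf(values, scales)
print(result)
```

With these inputs the first block is multiplied by 2.0 and the second by 0.5. A scale byte of `0xFF` makes every element of its block NaN, which matches the NaN propagation described in the suggestion.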

https://github.com/llvm/llvm-project/pull/141965

