[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu May 29 12:07:10 PDT 2025


================
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to larger floating-point using provided scales";
+  let description = [{
+    Implements the micro-scaling floating-point ExtF op. It expects the scale and input operands to have the same shape.
+    The scale operand is expected to be of type f8E8M0, but that can be relaxed in the future.
+    The scale is usually calculated per block.
+    Suppose the input originally has shape <dim1 x dim2 x dim3 .. x dimN>; given blockSize, it can be reshaped to <dim1 x dim2 x ... (dimN/blockSize) x blockSize>.
+    Scales are calculated along the block axis, so the scale has shape <dim1 x dim2 x dim3 ... (dimN/blockSize) x 1>.
+    Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to the same shape as the input, making `arith.scaling_extf` an elementwise op.
+    In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>.
+    ```
+    resultTy = get_type(result) 
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
----------------
krzysz00 wrote:

There are reasons not to fix the scale type to f8E8M0.

The main one is hardware intrinsics that take f32 scales (and then ignore the sign and mantissa), where we'd like to avoid redundant conversions to f8E8M0.
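
To illustrate why the conversion would be redundant, here is a rough Python sketch of what "ignore the sign and mantissa" amounts to for an f32 scale. The helper names (`f32_bits`, `e8m0_exponent_from_f32`, `e8m0_value`) are hypothetical and not part of the PR, and the sketch assumes plain bit-field extraction; it does not model any rounding that `arith.truncf` itself might apply, nor any particular hardware intrinsic.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def e8m0_exponent_from_f32(scale: float) -> int:
    """Keep only the 8-bit biased exponent field, dropping sign and mantissa.

    Assumption: this mirrors hardware that reads an f32 scale and ignores
    everything except the exponent bits.
    """
    return (f32_bits(scale) >> 23) & 0xFF


def e8m0_value(biased_exp: int) -> float:
    """Decode an f8E8M0 byte: 2**(e - 127), with 0xFF reserved for NaN."""
    if biased_exp == 0xFF:
        return float("nan")
    return 2.0 ** (biased_exp - 127)


if __name__ == "__main__":
    for s in (1.0, 0.25, -6.5):
        e = e8m0_exponent_from_f32(s)
        print(f"scale={s!r} exponent_field={e} effective_scale={e8m0_value(e)}")
```

Under these assumptions, an f32 scale of `-6.5` and an f8E8M0 scale of `4.0` carry the same exponent field, so an explicit conversion before such an intrinsic would do no useful work.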

https://github.com/llvm/llvm-project/pull/141965

