[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Umang Yadav llvmlistbot at llvm.org
Thu May 29 08:46:10 PDT 2025


https://github.com/umangyadav updated https://github.com/llvm/llvm-project/pull/141965

>From 1ed7462091f21b80d8b67a6543809eabc8ef8149 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 19:51:22 +0000
Subject: [PATCH 01/12] Make it elementwise op

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 81 +++++++++++++++++++
 .../mlir/Dialect/Arith/Transforms/Passes.h    |  3 +
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 26 ++++++
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 47 ++++++++++-
 4 files changed, 155 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 599b3b982ec7f..ba62f3b7730b3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingExtFOp :
+    Arith_Op<"scaling_extf",
+      [Pure, 
+       SameInputOutputTensorDims, 
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   FloatLike:$scale,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+    Results<(outs FloatLike:$out)> {
+  let summary = "cast from floating-point to larger floating-point using provided scales";
+  let description = [{
+    Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
+    Scale operand is usually expected to be of type f8E8M0 but that is not strictly necessary. 
+    Scale is usually calculated per block. 
+    It is assumed that Scale operand is broadcasted appropariately to make it of same shape as Input operand so that `arith.scaling_extf` an elementwise op.
+    ```
+    resultTy = get_type(result) 
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.bcast = broadcast_to_same_shape_as(result)
+    scale.extf = arith.extf(scale.bcast) : f8E8M0 to resultTy
+    input.extf = arith.extf(input) : inputTy to resultTy
+    result = arith.mulf(scale.extf, input.extf)
+    ```
+    It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1280,6 +1318,49 @@ def Arith_TruncFOp :
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp :
+    Arith_Op<"scaling_truncf",
+      [Pure, 
+       SameInputOutputTensorDims, 
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   FloatLike:$scale,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+    Results<(outs FloatLike:$out)> {
+  let summary = "cast from floating-point to narrower floating-point with scales";
+  let description = [{
+    This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values.
+    This quantization usually happens over a block of values. All values in that block share the same scale value for quantization purposes.
+    Therefore, an original input of shape `<dim1 x dim2 x ... x dimN>` can be thought of as having shape `<dim1 x dim2 x ... x (dimN / blockSize) x blockSize>`,
+    assuming the quantization axis is the last axis.
+    The original scale values would therefore have shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`.
+    `arith.scaling_truncf` is an elementwise operation. Therefore, before calling `arith.scaling_truncf`, if `blockSize != 1`, the
+    scales must be broadcast appropriately to the same shape as the input operand.
+    Internally arith.scaling_truncf does the following:
+    ```
+    scaleETy = get_type(scale)
+    inputETy = get_type(input)
+    resultETy = get_type(result)
+    scale.bcast = broadcast_to_same_shape_as(input)
+    scale.exponent = arith.truncf(scale.bcast) : scaleETy to f8E8M0
+    scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultETy)
+    ```
+    The OCP MXFP spec flushes denormal input values to zero before quantization. NaNs are propagated.
+
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 5aaac8d8e3dc5..bba335431aee6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
+void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 41f2d0f3425e2..9e53e195274aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 
 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 
+//===----------------------------------------------------------------------===//
+// ScalingExtFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingExtFOp::verify() {
+  return verifyExtOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
   return verifyTruncateOp<FloatType>(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// ScalingTruncFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
+                                               TypeRange outputs) {
+  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingTruncFOp::verify() {
+  return verifyTruncateOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 95546bb09e765..bf8d5e434eee0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -409,6 +409,40 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto inputOperand = op.getIn();
+    auto scaleOperand = op.getScale();
+    Type resultTy = op.getType();
+    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
+    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
+    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto inputOperand = op.getIn();
+    auto scaleOperand = op.getScale();
+    Type resultTy = op.getType();
+    Type inputTy = inputOperand.getType();
+    Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleExt);
+    Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
+    rewriter.replaceOp(op, resultCast);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +466,9 @@ struct ArithExpandOpsPass
       arith::MaximumFOp,
       arith::MinimumFOp,
       arith::MaxNumFOp,
-      arith::MinNumFOp
+      arith::MinNumFOp,
+      arith::ScalingExtFOp,
+      arith::ScalingTruncFOp
     >();
 
     if (includeBf16) {
@@ -492,8 +528,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandScalingExtTruncPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
+  populateExpandScalingExtTruncPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +546,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
-    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
+    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> 
    >(patterns.getContext());
   // clang-format on
 }

>From 91bb889b88cce6584a566f308ad1331f7b657cfb Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 20:05:18 +0000
Subject: [PATCH 02/12] Add flushing logic

---
 mlir/include/mlir/IR/Builders.h                 |  1 +
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 13 ++++++++++++-
 mlir/lib/IR/Builders.cpp                        |  2 ++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3f7b3268dd085..d68dbdb1efeef 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getF8E8M0Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index bf8d5e434eee0..203c87fa9ec76 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -436,7 +436,18 @@ struct ScalingTruncFOpConverter
     Type resultTy = op.getType();
     Type inputTy = inputOperand.getType();
     Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
-    Value result = b.create<arith::DivFOp>(inputOperand, scaleExt);
+    // flush denorms, check if exponent part of input operand is zero or not.
+    Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
+    Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
+    Value inputExponent = b.create<arith::TruncFOp>(inputOperand, f8E8M0Ty);
+    Value inputExponentU8 = b.create<arith::BitcastOp>(inputExponent, i8Ty);
+    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
+                                            inputExponentU8);
+    Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
+    Value flushedInput =
+        b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
+    Value result = b.create<arith::DivFOp>(flushedInput, scaleExt);
     Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
     rewriter.replaceOp(op, resultCast);
     return success();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 89102115cdc40..5f7bc50afc418 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+
 FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
 
 FloatType Builder::getF16Type() { return Float16Type::get(context); }

>From 8eebbea9d7848c6a1db5dc3369f8482cc5102fcc Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 20:09:14 +0000
Subject: [PATCH 03/12] Fix build issues

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 203c87fa9ec76..2ab498b22a3b4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -439,8 +439,8 @@ struct ScalingTruncFOpConverter
     // flush denorms, check if exponent part of input operand is zero or not.
     Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
     Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
-    Value inputExponent = b.create<arith::TruncFOp>(inputOperand, f8E8M0Ty);
-    Value inputExponentU8 = b.create<arith::BitcastOp>(inputExponent, i8Ty);
+    Value inputExponent = b.create<arith::TruncFOp>(f8E8M0Ty, inputOperand);
+    Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                             inputExponentU8);

>From acc66584077de053b7f74044b32d63c5d8a75250 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 00:43:14 +0000
Subject: [PATCH 04/12] clamping on exponent

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 66 ++++++++++++++++---
 1 file changed, 58 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2ab498b22a3b4..dfec877f2535a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,13 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -417,6 +420,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
     Type resultTy = op.getType();
+    // extf on scale will essentially create f32 number that is 2^scale
     Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
     Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
     Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
@@ -433,13 +437,58 @@ struct ScalingTruncFOpConverter
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
+    auto scaleTy = scaleOperand.getType();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+    }
+    auto scaleETy = getElementTypeOrSelf(scaleOperand);
     Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(op.getOut());
+
     Type inputTy = inputOperand.getType();
-    Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
-    // flush denorms, check if exponent part of input operand is zero or not.
-    Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
-    Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
-    Value inputExponent = b.create<arith::TruncFOp>(f8E8M0Ty, inputOperand);
+    Type inputETy = getElementTypeOrSelf(inputOperand);
+    if (!inputETy.isF32()) {
+      inputOperand = b.create<arith::ExtFOp>(b.getF32Type(), inputOperand);
+      inputETy = getElementTypeOrSelf(inputOperand);
+    }
+
+    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
+    Type f8Ty = cloneToShapedType(scaleTy, b.getF8E8M0Type());
+
+    // normalize scale by exponent of the max normal value in result type as per
+    // the OCP MXFP spec
+    const llvm::fltSemantics &resultFltSemantics =
+        llvm::cast<FloatType>(resultETy).getFloatSemantics();
+    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
+    Value cMaxNormalExponent =
+        createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+    Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
+    Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
+    Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
+    Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
+    Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
+    Value normalizedUnbiasedScale =
+        b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
+    // clamp scale exponent
+    Value clampUpperCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
+    Value clampLowerCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
+    Value clampedScale = b.create<arith::SelectOp>(
+        clampUpperCond, c127,
+        b.create<arith::SelectOp>(clampLowerCond, cNeg127,
+                                  normalizedUnbiasedScale),
+        normalizedUnbiasedScale);
+    Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
+    Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
+    Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
+    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
+    // flush denorms, for that check if exponent part of input operand is zero
+    // or not.
+    Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
@@ -447,7 +496,8 @@ struct ScalingTruncFOpConverter
     Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
-    Value result = b.create<arith::DivFOp>(flushedInput, scaleExt);
+    Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
+    // TODO check if any sort of clamping is required or not
     Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
     rewriter.replaceOp(op, resultCast);
     return success();

>From 6797446099b124f0dfecad1b05fd4703f868e75e Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 00:51:33 +0000
Subject: [PATCH 05/12] propagate rounding mode and fast math attrs

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dfec877f2535a..b8801f43b7032 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -460,6 +460,7 @@ struct ScalingTruncFOpConverter
 
     // normalize scale by exponent of the max normal value in result type as per
     // the OCP MXFP spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
     const llvm::fltSemantics &resultFltSemantics =
         llvm::cast<FloatType>(resultETy).getFloatSemantics();
     int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
@@ -472,7 +473,11 @@ struct ScalingTruncFOpConverter
     Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
     Value normalizedUnbiasedScale =
         b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
-    // clamp scale exponent
+    // clamp scale exponent as per spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
+    // upper limit of 127 will be mapped to biased value of 255 and will be
+    // bitcasted to 0xFF in F8E8M0 which will be converted to f32 NaNs using
+    // extf
     Value clampUpperCond = b.create<arith::CmpIOp>(
         arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
     Value clampLowerCond = b.create<arith::CmpIOp>(
@@ -480,13 +485,12 @@ struct ScalingTruncFOpConverter
     Value clampedScale = b.create<arith::SelectOp>(
         clampUpperCond, c127,
         b.create<arith::SelectOp>(clampLowerCond, cNeg127,
-                                  normalizedUnbiasedScale),
-        normalizedUnbiasedScale);
+                                  normalizedUnbiasedScale));
     Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
     Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
     Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
     Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
-    // flush denorms, for that check if exponent part of input operand is zero
+    // flush denorms by checking if exponent part of input operand is zero
     // or not.
     Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
@@ -497,8 +501,9 @@ struct ScalingTruncFOpConverter
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
     Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
-    // TODO check if any sort of clamping is required or not
-    Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
+    // propagate rounding mode and fast math attributes
+    Value resultCast = b.create<arith::TruncFOp>(
+        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
     rewriter.replaceOp(op, resultCast);
     return success();
   }

>From 3ad83bdec10474a42d9249aa1d0a6a178f3ef7db Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 01:07:35 +0000
Subject: [PATCH 06/12] Add some more notes

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 15 +++++-----
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 28 ++++++++++++-------
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ba62f3b7730b3..cc8b536a29734 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1218,25 +1218,26 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 //===----------------------------------------------------------------------===//
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
-
+// scaling_extf takes no rounding-mode attribute (only fastmath).
 def Arith_ScalingExtFOp :
     Arith_Op<"scaling_extf",
       [Pure, 
        SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
        DeclareOpInterfaceMethods<ArithFastMathInterface>, 
        DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins FloatLike:$in,
                    FloatLike:$scale,
-                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
                    OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
     Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to larger floating-point using provided scales";
   let description = [{
     Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
-    Scale operand is usually expected to be of type f8E8M0 but that is not strictly necessary. 
-    Scale is usually calculated per block. 
-    It is assumed that Scale operand is broadcasted appropariately to make it of same shape as Input operand so that `arith.scaling_extf` an elementwise op.
+    The scale operand is expected to be of type f8E8M0, although this may be relaxed in the future.
+    The scale is usually calculated per block.
+    Suppose the original input has shape <dim1 x dim2 x dim3 x ... x dimN>; given blockSize, it can be reshaped to <dim1 x dim2 x ... x (dimN/blockSize) x blockSize>.
+    Scales are calculated along the block axis, so the scale has shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x 1>.
+    Before calling `arith.scaling_extf`, scales must be broadcast to the same shape as the input, making `arith.scaling_extf` an elementwise op.
+    In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>.
     ```
     resultTy = get_type(result) 
     scaleTy  = get_type(scale)
@@ -1250,7 +1251,7 @@ def Arith_ScalingExtFOp :
     It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat = [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index b8801f43b7032..45468df10fb37 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -419,8 +419,13 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling extf is not using scale operand of type f8E8M0FNU");
+    }
     Type resultTy = op.getType();
-    // extf on scale will essentially create f32 number that is 2^scale
+    // extf on scale yields the scale value (a power of two) in the result type;
+    // the f8E8M0 NaN encoding (0xFF) extends to NaN, so NaNs are propagated.
     Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
     Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
     Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
@@ -437,26 +442,29 @@ struct ScalingTruncFOpConverter
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
-    auto scaleTy = scaleOperand.getType();
     if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");
     }
     auto scaleETy = getElementTypeOrSelf(scaleOperand);
+
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(op.getOut());
 
     Type inputTy = inputOperand.getType();
     Type inputETy = getElementTypeOrSelf(inputOperand);
-    if (!inputETy.isF32()) {
-      inputOperand = b.create<arith::ExtFOp>(b.getF32Type(), inputOperand);
-      inputETy = getElementTypeOrSelf(inputOperand);
-    }
 
     Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
     Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
-    Type f8Ty = cloneToShapedType(scaleTy, b.getF8E8M0Type());
+    Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
+
+    if (inputETy.getIntOrFloatBitWidth() < 32) {
+      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
+    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
+      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
+    }
+    inputETy = getElementTypeOrSelf(inputOperand);
 
     // normalize scale by exponent of the max normal value in result type as per
     // the OCP MXFP spec
@@ -475,9 +483,9 @@ struct ScalingTruncFOpConverter
         b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
     // clamp scale exponent as per spec
     // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
-    // upper limit of 127 will be mapped to biased value of 255 and will be
-    // bitcasted to 0xFF in F8E8M0 which will be converted to f32 NaNs using
-    // extf
+    // upper clamp limit of 127 will be mapped to biased value of 255 and will
+    // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
+    // using arith.extf
     Value clampUpperCond = b.create<arith::CmpIOp>(
         arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
     Value clampLowerCond = b.create<arith::CmpIOp>(

>From 5e49a7207928e662e5cf43eae1b7997041cbe3f6 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 13:27:32 +0000
Subject: [PATCH 07/12] add scaling_extf tests

---
 mlir/test/Dialect/Arith/expand-ops.mlir | 63 ++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 5b6badf13d763..79930efdb9453 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -374,7 +375,67 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
 
 // CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
 // CHECK-NOT: arith.extf
+// CHECK: return
 
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+    return %0 : f32 
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+    // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+    return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_bf16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> 
+// SCHECK: return %[[RESULT]]
 
 // -----
 

>From 682573e9290d080c77d6bdbe7a63a060ff040e68 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 13:49:40 +0000
Subject: [PATCH 08/12] Fix some issues

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 17 +++++++++++++----
 mlir/test/Dialect/Arith/expand-ops.mlir         |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 45468df10fb37..524a77e7a153c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -26,6 +26,16 @@ namespace arith {
 
 using namespace mlir;
 
+/// Create a floating-point constant, splatted when `type` is shaped.
+static Value createFloatConst(Location loc, Type type, double value,
+                              PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type))
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
@@ -34,7 +44,6 @@ static Value createConst(Location loc, Type type, int value,
     return rewriter.create<arith::ConstantOp>(
         loc, DenseElementsAttr::get(shapedTy, attr));
   }
-
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
@@ -446,7 +455,7 @@ struct ScalingTruncFOpConverter
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");
     }
-    auto scaleETy = getElementTypeOrSelf(scaleOperand);
+    auto scaleTy = scaleOperand.getType();
 
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(op.getOut());
@@ -500,12 +509,12 @@ struct ScalingTruncFOpConverter
     Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
     // flush denorms by checking if exponent part of input operand is zero
     // or not.
-    Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
+    Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                             inputExponentU8);
-    Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
+    Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
     Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 79930efdb9453..59483bf2731dc 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -307,8 +307,8 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
 // CHECK-NOT: arith.truncf
 
-
 // -----
+
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f32
     return %0 : f32

>From de4497b3277e8408e2187e1924a4bb56f6bc7a61 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 14:49:14 +0000
Subject: [PATCH 09/12] add test for scaling_truncf

---
 mlir/test/Dialect/Arith/expand-ops.mlir | 33 +++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 59483bf2731dc..23700edea159b 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -306,7 +306,40 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 
 // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
 // CHECK-NOT: arith.truncf
+// CHECK: return
+
+// -----
 
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// SCHECK: %[[C2:.+]] = arith.constant 2 : i32
+// SCHECK: %[[C127:.+]] = arith.constant 127 : i32
+// SCHECK: %[[CN127:.+]] = arith.constant -127 : i32
+// SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : f8E8M0FNU to i8
+// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : i8 to i32
+// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] : i32
+// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : i32
+// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : i32
+// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : i32
+// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : i32
+// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : i32
+// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : i32
+// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : i32 to i8
+// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : i8 to f8E8M0FNU
+// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : f8E8M0FNU to f32
+// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %arg0 : f32 to f8E8M0FNU
+// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : f8E8M0FNU to i8
+// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
+// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : i8
+// SCHECK: %[[CF0:.+]] = arith.constant 0.000000e+00 : f32
+// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %arg0 : f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : f32
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// SCHECK: return %[[RESULT]]
 // -----
 
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {

>From e239157087dc9d6d878c9a48e98652536935fc02 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:23:53 +0000
Subject: [PATCH 10/12] add some more tests

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    |  1 +
 mlir/test/Dialect/Arith/expand-ops.mlir       | 53 +++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 524a77e7a153c..8ad78cd74d8ca 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -473,6 +473,7 @@ struct ScalingTruncFOpConverter
     } else if (inputETy.getIntOrFloatBitWidth() > 32) {
       inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
     }
+    inputTy = inputOperand.getType();
     inputETy = getElementTypeOrSelf(inputOperand);
 
     // normalize scale by exponent of the max normal value in result type as per
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 23700edea159b..87f243d3cbdb4 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -340,6 +340,59 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 // SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : f32
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
 // SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// SCHECK: %[[INPUTF32:.+]] = arith.extf %arg0 : vector<4xf16> to vector<4xf32>
+// SCHECK: %[[C4:.+]] = arith.constant dense<4> : vector<4xi32>
+// SCHECK: %[[C127:.+]] = arith.constant dense<127> : vector<4xi32>
+// SCHECK: %[[CN127:.+]] = arith.constant dense<-127> : vector<4xi32>
+// SCHECK: %[[BITCASTI8:.+]] = arith.bitcast %arg1 : vector<4xf8E8M0FNU> to vector<4xi8>
+// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : vector<4xi8> to vector<4xi32>
+// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] : vector<4xi32>
+// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C4]] : vector<4xi32>
+// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : vector<4xi32>
+// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : vector<4xi32>
+// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : vector<4xi1>, vector<4xi32>
+// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : vector<4xi1>, vector<4xi32>
+// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : vector<4xi32>
+// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : vector<4xi32> to vector<4xi8>
+// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : vector<4xi8> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %[[INPUTF32]] : vector<4xf32> to vector<4xf8E8M0FNU>
+// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : vector<4xf8E8M0FNU> to vector<4xi8> 
+// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
+// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : vector<4xi8>
+// SCHECK: %[[CF0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %[[INPUTF32]] : vector<4xi1>, vector<4xf32>
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf %{{[a-zA-Z0-9_]+}} to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+    // expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
 // -----
 
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {

>From 646465cebd3265a118451a4cdce7d84ea84ff84b Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:24:59 +0000
Subject: [PATCH 11/12] Fix Formatting

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 52 +++++++++----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index cc8b536a29734..34ef1cb19c0d9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1219,17 +1219,16 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
 // TODO Remove rouding mode attr for EXtf
-def Arith_ScalingExtFOp :
-    Arith_Op<"scaling_extf",
-      [Pure, 
-       SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
-       DeclareOpInterfaceMethods<CastOpInterface>]>,
-    Arguments<(ins FloatLike:$in,
-                   FloatLike:$scale,
-                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
-    Results<(outs FloatLike:$out)> {
-  let summary = "cast from floating-point to larger floating-point using provided scales";
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to larger floating-point using provided scales";
   let description = [{
     Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
     Scale operand is expected to be of type f8E8M0. But that can be relaxed in future.  
@@ -1251,7 +1250,8 @@ def Arith_ScalingExtFOp :
     It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat =
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1323,19 +1323,18 @@ def Arith_TruncFOp :
 // Scaling TruncFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ScalingTruncFOp :
-    Arith_Op<"scaling_truncf",
-      [Pure, 
-       SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
-       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
-       DeclareOpInterfaceMethods<CastOpInterface>]>,
-    Arguments<(ins FloatLike:$in,
-                   FloatLike:$scale,
-                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
-                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
-    Results<(outs FloatLike:$out)> {
-  let summary = "cast from floating-point to narrower floating-point with scales";
+def Arith_ScalingTruncFOp
+    : Arith_Op<"scaling_truncf",
+               [Pure, SameInputOutputTensorDims,
+                DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+                DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to narrower floating-point with scales";
   let description = [{
     This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values.
     This quantization usually happens over a block of values. All values in that block share same scale value for quantization purposes. 
@@ -1359,7 +1358,8 @@ def Arith_ScalingTruncFOp :
 
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat =
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//

>From 80c080fcb797a58bd5aab4e8322c5436c962f6df Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:45:47 +0000
Subject: [PATCH 12/12] Remove TODO

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 34ef1cb19c0d9..5d2d545cf9234 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1218,7 +1218,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 //===----------------------------------------------------------------------===//
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
-// TODO Remove rouding mode attr for EXtf
 def Arith_ScalingExtFOp
     : Arith_Op<
           "scaling_extf", [Pure, SameInputOutputTensorDims,



More information about the Mlir-commits mailing list