[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)

Umang Yadav llvmlistbot at llvm.org
Mon Jun 9 08:24:29 PDT 2025


https://github.com/umangyadav updated https://github.com/llvm/llvm-project/pull/141965

>From 1ed7462091f21b80d8b67a6543809eabc8ef8149 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 19:51:22 +0000
Subject: [PATCH 01/30] Make it elementwise op

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 81 +++++++++++++++++++
 .../mlir/Dialect/Arith/Transforms/Passes.h    |  3 +
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 26 ++++++
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 47 ++++++++++-
 4 files changed, 155 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 599b3b982ec7f..ba62f3b7730b3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingExtFOp :
+    Arith_Op<"scaling_extf",
+      [Pure, 
+       SameInputOutputTensorDims, 
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   FloatLike:$scale,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+    Results<(outs FloatLike:$out)> {
+  let summary = "cast from floating-point to larger floating-point using provided scales";
+  let description = [{
+    Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
+    Scale operand is usually expected to be of type f8E8M0 but that is not strictly necessary. 
+    Scale is usually calculated per block. 
+    It is assumed that Scale operand is broadcasted appropariately to make it of same shape as Input operand so that `arith.scaling_extf` an elementwise op.
+    ```
+    resultTy = get_type(result) 
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.bcast = broadcast_to_same_shape_as(result)
+    scale.extf = arith.extf(scale.bcast) : f8E8M0 to resultTy
+    input.extf = arith.extf(input) : inputTy to resultTy
+    result = arith.mulf(scale.extf, input.extf)
+    ```
+    It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1280,6 +1318,49 @@ def Arith_TruncFOp :
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp :
+    Arith_Op<"scaling_truncf",
+      [Pure, 
+       SameInputOutputTensorDims, 
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   FloatLike:$scale,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+    Results<(outs FloatLike:$out)> {
+  let summary = "cast from floating-point to narrower floating-point with scales";
+  let description = [{
+    This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values.
+    This quantization usually happens over a block of values. All values in that block share same scale value for quantization purposes. 
+    Therefore original input of shape `<dim1 x dim2 ... dimN>` can be thought of as of shape `<dim1 x dim2 x ... (dimN / blockSize) x blockSize>`, 
+    assuming quantization axis is the last axis. 
+    The original scale values would therefore be of shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`.
+    `arith.scaling_truncf` is an elementwise operation. Therefore, before calling into `arith.scaling_truncf`, if `blockSize != 1` then
+    scales must be broadcast appropriately to make them the same shape as the input operand.
+    Internally arith.scaling_truncf does the following:
+    ```
+    scaleETy = get_type(scale)
+    inputETy = get_type(input)
+    resultETy = get_type(result)
+    scale.bcast = broadcast_to_same_shape_as(input)
+    scale.exponent = arith.truncf(scale.bcast) : scaleETy to f8E8M0
+    scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultETy)
+    ```
+    OCP MXFP spec flushes denorm input value before quantization. NaNs are propagated. 
+
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 5aaac8d8e3dc5..bba335431aee6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
+void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 41f2d0f3425e2..9e53e195274aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 
 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 
+//===----------------------------------------------------------------------===//
+// ScalingExtFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingExtFOp::verify() {
+  return verifyExtOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
   return verifyTruncateOp<FloatType>(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// ScalingTruncFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
+                                               TypeRange outputs) {
+  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingTruncFOp::verify() {
+  return verifyTruncateOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 95546bb09e765..bf8d5e434eee0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -409,6 +409,40 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto inputOperand = op.getIn();
+    auto scaleOperand = op.getScale();
+    Type resultTy = op.getType();
+    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
+    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
+    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto inputOperand = op.getIn();
+    auto scaleOperand = op.getScale();
+    Type resultTy = op.getType();
+    Type inputTy = inputOperand.getType();
+    Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleExt);
+    Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
+    rewriter.replaceOp(op, resultCast);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +466,9 @@ struct ArithExpandOpsPass
       arith::MaximumFOp,
       arith::MinimumFOp,
       arith::MaxNumFOp,
-      arith::MinNumFOp
+      arith::MinNumFOp,
+      arith::ScalingExtFOp,
+      arith::ScalingTruncFOp
     >();
 
     if (includeBf16) {
@@ -492,8 +528,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandScalingExtTruncPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
+  populateExpandScalingExtTruncPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +546,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
-    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
+    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> 
    >(patterns.getContext());
   // clang-format on
 }

>From 91bb889b88cce6584a566f308ad1331f7b657cfb Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 20:05:18 +0000
Subject: [PATCH 02/30] Add flushing logic

---
 mlir/include/mlir/IR/Builders.h                 |  1 +
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 13 ++++++++++++-
 mlir/lib/IR/Builders.cpp                        |  2 ++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3f7b3268dd085..d68dbdb1efeef 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getF8E8M0Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index bf8d5e434eee0..203c87fa9ec76 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -436,7 +436,18 @@ struct ScalingTruncFOpConverter
     Type resultTy = op.getType();
     Type inputTy = inputOperand.getType();
     Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
-    Value result = b.create<arith::DivFOp>(inputOperand, scaleExt);
+    // flush denorms, check if exponent part of input operand is zero or not.
+    Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
+    Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
+    Value inputExponent = b.create<arith::TruncFOp>(inputOperand, f8E8M0Ty);
+    Value inputExponentU8 = b.create<arith::BitcastOp>(inputExponent, i8Ty);
+    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
+                                            inputExponentU8);
+    Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
+    Value flushedInput =
+        b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
+    Value result = b.create<arith::DivFOp>(flushedInput, scaleExt);
     Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
     rewriter.replaceOp(op, resultCast);
     return success();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 89102115cdc40..5f7bc50afc418 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+
 FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
 
 FloatType Builder::getF16Type() { return Float16Type::get(context); }

>From 8eebbea9d7848c6a1db5dc3369f8482cc5102fcc Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Wed, 28 May 2025 20:09:14 +0000
Subject: [PATCH 03/30] Fix build issues

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 203c87fa9ec76..2ab498b22a3b4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -439,8 +439,8 @@ struct ScalingTruncFOpConverter
     // flush denorms, check if exponent part of input operand is zero or not.
     Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
     Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
-    Value inputExponent = b.create<arith::TruncFOp>(inputOperand, f8E8M0Ty);
-    Value inputExponentU8 = b.create<arith::BitcastOp>(inputExponent, i8Ty);
+    Value inputExponent = b.create<arith::TruncFOp>(f8E8M0Ty, inputOperand);
+    Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                             inputExponentU8);

>From acc66584077de053b7f74044b32d63c5d8a75250 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 00:43:14 +0000
Subject: [PATCH 04/30] clamping on exponent

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 66 ++++++++++++++++---
 1 file changed, 58 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2ab498b22a3b4..dfec877f2535a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,13 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -417,6 +420,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
     Type resultTy = op.getType();
+    // extf on scale will essentially create f32 number that is 2^scale
     Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
     Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
     Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
@@ -433,13 +437,58 @@ struct ScalingTruncFOpConverter
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
+    auto scaleTy = scaleOperand.getType();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+    }
+    auto scaleETy = getElementTypeOrSelf(scaleOperand);
     Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(op.getOut());
+
     Type inputTy = inputOperand.getType();
-    Value scaleExt = b.create<arith::ExtFOp>(inputTy, scaleOperand);
-    // flush denorms, check if exponent part of input operand is zero or not.
-    Type f8E8M0Ty = cloneToShapedType(inputTy, b.getF8E8M0Type());
-    Type i8Ty = cloneToShapedType(inputTy, b.getI8Type());
-    Value inputExponent = b.create<arith::TruncFOp>(f8E8M0Ty, inputOperand);
+    Type inputETy = getElementTypeOrSelf(inputOperand);
+    if (!inputETy.isF32()) {
+      inputOperand = b.create<arith::ExtFOp>(b.getF32Type(), inputOperand);
+      inputETy = getElementTypeOrSelf(inputOperand);
+    }
+
+    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
+    Type f8Ty = cloneToShapedType(scaleTy, b.getF8E8M0Type());
+
+    // normalize scale by exponent of the max normal value in result type as per
+    // the OCP MXFP spec
+    const llvm::fltSemantics &resultFltSemantics =
+        llvm::cast<FloatType>(resultETy).getFloatSemantics();
+    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
+    Value cMaxNormalExponent =
+        createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+    Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
+    Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
+    Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
+    Value scaleI32 = b.create<arith::ExtUIOp>(i32Ty, scaleI8); // u8 biased exp
+    Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
+    Value normalizedUnbiasedScale =
+        b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
+    // clamp scale exponent
+    Value clampUpperCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
+    Value clampLowerCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
+    Value clampedScale = b.create<arith::SelectOp>(
+        clampUpperCond, c127,
+        b.create<arith::SelectOp>(clampLowerCond, cNeg127,
+                                  normalizedUnbiasedScale),
+        normalizedUnbiasedScale);
+    Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
+    Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
+    Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
+    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
+    // flush denorms, for that check if exponent part of input operand is zero
+    // or not.
+    Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
@@ -447,7 +496,8 @@ struct ScalingTruncFOpConverter
     Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
-    Value result = b.create<arith::DivFOp>(flushedInput, scaleExt);
+    Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
+    // TODO check if any sort of clamping is required or not
     Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
     rewriter.replaceOp(op, resultCast);
     return success();

>From 6797446099b124f0dfecad1b05fd4703f868e75e Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 00:51:33 +0000
Subject: [PATCH 05/30] propagate rounding mode and fast math attrs

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dfec877f2535a..b8801f43b7032 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -460,6 +460,7 @@ struct ScalingTruncFOpConverter
 
     // normalize scale by exponent of the max normal value in result type as per
     // the OCP MXFP spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
     const llvm::fltSemantics &resultFltSemantics =
         llvm::cast<FloatType>(resultETy).getFloatSemantics();
     int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
@@ -472,7 +473,11 @@ struct ScalingTruncFOpConverter
     Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
     Value normalizedUnbiasedScale =
         b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
-    // clamp scale exponent
+    // clamp scale exponent as per spec
+    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
+    // upper limit of 127 will be mapped to biased value of 255 and will be
+    // bitcasted to 0xFF in F8E8M0 which will be converted to f32 NaNs using
+    // extf
     Value clampUpperCond = b.create<arith::CmpIOp>(
         arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
     Value clampLowerCond = b.create<arith::CmpIOp>(
@@ -480,13 +485,12 @@ struct ScalingTruncFOpConverter
     Value clampedScale = b.create<arith::SelectOp>(
         clampUpperCond, c127,
         b.create<arith::SelectOp>(clampLowerCond, cNeg127,
-                                  normalizedUnbiasedScale),
-        normalizedUnbiasedScale);
+                                  normalizedUnbiasedScale));
     Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
     Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
     Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
     Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
-    // flush denorms, for that check if exponent part of input operand is zero
+    // flush denorms by checking if exponent part of input operand is zero
     // or not.
     Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
@@ -497,8 +501,9 @@ struct ScalingTruncFOpConverter
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
     Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
-    // TODO check if any sort of clamping is required or not
-    Value resultCast = b.create<arith::TruncFOp>(resultTy, result);
+    // propagate rounding mode and fast math attributes
+    Value resultCast = b.create<arith::TruncFOp>(
+        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
     rewriter.replaceOp(op, resultCast);
     return success();
   }

>From 3ad83bdec10474a42d9249aa1d0a6a178f3ef7db Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 01:07:35 +0000
Subject: [PATCH 06/30] Add some more notes

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 15 +++++-----
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 28 ++++++++++++-------
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ba62f3b7730b3..cc8b536a29734 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1218,25 +1218,26 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 //===----------------------------------------------------------------------===//
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
-
+// TODO Remove rouding mode attr for EXtf
 def Arith_ScalingExtFOp :
     Arith_Op<"scaling_extf",
       [Pure, 
        SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
        DeclareOpInterfaceMethods<ArithFastMathInterface>, 
        DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins FloatLike:$in,
                    FloatLike:$scale,
-                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
                    OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
     Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to larger floating-point using provided scales";
   let description = [{
     Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
-    Scale operand is usually expected to be of type f8E8M0 but that is not strictly necessary. 
-    Scale is usually calculated per block. 
-    It is assumed that Scale operand is broadcasted appropariately to make it of same shape as Input operand so that `arith.scaling_extf` an elementwise op.
+    Scale operand is expected to be of type f8E8M0. But that can be relaxed in future.  
+    Scale is usually calculated per block.  
+    Let's say originally input is shape <dim1 x dim2 x dim3 .. x dimN> then, given blockSize it can be reshaped to <dim1 x dim2 x ... (dimN/blockSize) x blockSize>. 
+    Scales will be calculated on the block axis. Therefore scale will be of shape <dim1 x dim2 x dim3 ... (dimN/blockSize) x 1>. 
+    Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to make them the same shape as the input, making `arith.scaling_extf` an elementwise op.
+    In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>.
     ```
     resultTy = get_type(result) 
     scaleTy  = get_type(scale)
@@ -1250,7 +1251,7 @@ def Arith_ScalingExtFOp :
     It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat = [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index b8801f43b7032..45468df10fb37 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -419,8 +419,13 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
+    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling extf is not using scale operand of type f8E8M0FNU");
+    }
     Type resultTy = op.getType();
-    // extf on scale will essentially create f32 number that is 2^scale
+    // extf on scale will essentially create f32 number that is 2^scale and will
+    // also propagate NaNs
     Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
     Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
     Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
@@ -437,26 +442,29 @@ struct ScalingTruncFOpConverter
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto inputOperand = op.getIn();
     auto scaleOperand = op.getScale();
-    auto scaleTy = scaleOperand.getType();
     if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");
     }
     auto scaleETy = getElementTypeOrSelf(scaleOperand);
+
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(op.getOut());
 
     Type inputTy = inputOperand.getType();
     Type inputETy = getElementTypeOrSelf(inputOperand);
-    if (!inputETy.isF32()) {
-      inputOperand = b.create<arith::ExtFOp>(b.getF32Type(), inputOperand);
-      inputETy = getElementTypeOrSelf(inputOperand);
-    }
 
     Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
     Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
-    Type f8Ty = cloneToShapedType(scaleTy, b.getF8E8M0Type());
+    Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
+
+    if (inputETy.getIntOrFloatBitWidth() < 32) {
+      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
+    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
+      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
+    }
+    inputETy = getElementTypeOrSelf(inputOperand);
 
     // normalize scale by exponent of the max normal value in result type as per
     // the OCP MXFP spec
@@ -475,9 +483,9 @@ struct ScalingTruncFOpConverter
         b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
     // clamp scale exponent as per spec
     // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
-    // upper limit of 127 will be mapped to biased value of 255 and will be
-    // bitcasted to 0xFF in F8E8M0 which will be converted to f32 NaNs using
-    // extf
+    // upper clamp limit of 127 will be mapped to biased value of 255 and will
+    // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
+    // using arith.extf
     Value clampUpperCond = b.create<arith::CmpIOp>(
         arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
     Value clampLowerCond = b.create<arith::CmpIOp>(

>From 5e49a7207928e662e5cf43eae1b7997041cbe3f6 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 13:27:32 +0000
Subject: [PATCH 07/30] add scaling_extf tests

---
 mlir/test/Dialect/Arith/expand-ops.mlir | 63 ++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 5b6badf13d763..79930efdb9453 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s 
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -374,7 +375,67 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
 
 // CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
 // CHECK-NOT: arith.extf
+// CHECK: return
 
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+    return %0 : f32 
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+    // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+    return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_bf16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> 
+// SCHECK: return %[[RESULT]]
 
 // -----
 

>From 682573e9290d080c77d6bdbe7a63a060ff040e68 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 13:49:40 +0000
Subject: [PATCH 08/30] Fix float zero constant and scale type in scaling_truncf expansion

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 17 +++++++++++++----
 mlir/test/Dialect/Arith/expand-ops.mlir         |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 45468df10fb37..524a77e7a153c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -26,6 +26,16 @@ namespace arith {
 
 using namespace mlir;
 
+static Value createFloatConst(Location loc, Type type, float value,
+                              PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
@@ -34,7 +44,6 @@ static Value createConst(Location loc, Type type, int value,
     return rewriter.create<arith::ConstantOp>(
         loc, DenseElementsAttr::get(shapedTy, attr));
   }
-
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
@@ -446,7 +455,7 @@ struct ScalingTruncFOpConverter
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");
     }
-    auto scaleETy = getElementTypeOrSelf(scaleOperand);
+    auto scaleTy = scaleOperand.getType();
 
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(op.getOut());
@@ -500,12 +509,12 @@ struct ScalingTruncFOpConverter
     Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
     // flush denorms by checking if exponent part of input operand is zero
     // or not.
-    Value inputExponent = b.create<arith::TruncFOp>(scaleETy, inputOperand);
+    Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
     Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                             inputExponentU8);
-    Value inputTyZero = createConst(op.getLoc(), inputTy, 0, rewriter);
+    Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
     Value flushedInput =
         b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
     Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 79930efdb9453..59483bf2731dc 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -307,8 +307,8 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
 // CHECK-NOT: arith.truncf
 
-
 // -----
+
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f32
     return %0 : f32

>From de4497b3277e8408e2187e1924a4bb56f6bc7a61 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 14:49:14 +0000
Subject: [PATCH 09/30] add test for scaling_truncf

---
 mlir/test/Dialect/Arith/expand-ops.mlir | 33 +++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 59483bf2731dc..23700edea159b 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -306,7 +306,40 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 
 // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
 // CHECK-NOT: arith.truncf
+// CHECK: return
+
+// -----
 
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// SCHECK: %[[C2:.+]] = arith.constant 2 : i32
+// SCHECK: %[[C127:.+]] = arith.constant 127 : i32
+// SCHECK: %[[CN127:.+]] = arith.constant -127 : i32
+// SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : f8E8M0FNU to i8
+// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : i8 to i32
+// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] : i32
+// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : i32
+// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : i32
+// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : i32
+// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : i32
+// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : i32
+// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : i32
+// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : i32 to i8
+// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : i8 to f8E8M0FNU
+// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : f8E8M0FNU to f32
+// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %arg0 : f32 to f8E8M0FNU
+// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : f8E8M0FNU to i8
+// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
+// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : i8
+// SCHECK: %[[CF0:.+]] = arith.constant 0.000000e+00 : f32
+// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %arg0 : f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : f32
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// SCHECK: return %[[RESULT]]
 // -----
 
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {

>From e239157087dc9d6d878c9a48e98652536935fc02 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:23:53 +0000
Subject: [PATCH 10/30] add some more tests

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    |  1 +
 mlir/test/Dialect/Arith/expand-ops.mlir       | 53 +++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 524a77e7a153c..8ad78cd74d8ca 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -473,6 +473,7 @@ struct ScalingTruncFOpConverter
     } else if (inputETy.getIntOrFloatBitWidth() > 32) {
       inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
     }
+    inputTy = inputOperand.getType();
     inputETy = getElementTypeOrSelf(inputOperand);
 
     // normalize scale by exponent of the max normal value in result type as per
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 23700edea159b..87f243d3cbdb4 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -340,6 +340,59 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 // SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : f32
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
 // SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// SCHECK: %[[INPUTF32:.+]] = arith.extf %arg0 : vector<4xf16> to vector<4xf32>
+// SCHECK: %[[C2:.+]] = arith.constant dense<4> : vector<4xi32>
+// SCHECK: %[[C127:.+]] = arith.constant dense<127> : vector<4xi32>
+// SCHECK: %[[CN127:.+]] = arith.constant dense<-127> : vector<4xi32>
+// SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : vector<4xf8E8M0FNU> to vector<4xi8>
+// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : vector<4xi8> to vector<4xi32>
+// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] :  vector<4xi32>
+// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : vector<4xi32>
+// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : vector<4xi32>
+// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : vector<4xi32>
+// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : vector<4xi1>, vector<4xi32>
+// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : vector<4xi1>, vector<4xi32>
+// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : vector<4xi32>
+// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : vector<4xi32> to vector<4xi8>
+// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : vector<4xi8> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %[[INPUTF32]] : vector<4xf32> to vector<4xf8E8M0FNU>
+// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : vector<4xf8E8M0FNU> to vector<4xi8> 
+// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
+// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : vector<4xi8>
+// SCHECK: %[[CF0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %[[INPUTF32]] : vector<4xi1>, vector<4xf32>
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+    // expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
 // -----
 
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {

>From 646465cebd3265a118451a4cdce7d84ea84ff84b Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:24:59 +0000
Subject: [PATCH 11/30] Fix Formatting

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 52 +++++++++----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index cc8b536a29734..34ef1cb19c0d9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1219,17 +1219,16 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
 // TODO Remove rouding mode attr for EXtf
-def Arith_ScalingExtFOp :
-    Arith_Op<"scaling_extf",
-      [Pure, 
-       SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
-       DeclareOpInterfaceMethods<CastOpInterface>]>,
-    Arguments<(ins FloatLike:$in,
-                   FloatLike:$scale,
-                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
-    Results<(outs FloatLike:$out)> {
-  let summary = "cast from floating-point to larger floating-point using provided scales";
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to larger floating-point using provided scales";
   let description = [{
     Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
     Scale operand is expected to be of type f8E8M0. But that can be relaxed in future.  
@@ -1251,7 +1250,8 @@ def Arith_ScalingExtFOp :
     It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat =
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1323,19 +1323,18 @@ def Arith_TruncFOp :
 // Scaling TruncFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ScalingTruncFOp :
-    Arith_Op<"scaling_truncf",
-      [Pure, 
-       SameInputOutputTensorDims, 
-       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
-       DeclareOpInterfaceMethods<ArithFastMathInterface>, 
-       DeclareOpInterfaceMethods<CastOpInterface>]>,
-    Arguments<(ins FloatLike:$in,
-                   FloatLike:$scale,
-                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
-                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
-    Results<(outs FloatLike:$out)> {
-  let summary = "cast from floating-point to narrower floating-point with scales";
+def Arith_ScalingTruncFOp
+    : Arith_Op<"scaling_truncf",
+               [Pure, SameInputOutputTensorDims,
+                DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+                DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary =
+      "cast from floating-point to narrower floating-point with scales";
   let description = [{
     This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values.
     This quantization usually happens over a block of values. All values in that block share same scale value for quantization purposes. 
@@ -1359,7 +1358,8 @@ def Arith_ScalingTruncFOp :
 
   }];
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+  let assemblyFormat =
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//

>From 80c080fcb797a58bd5aab4e8322c5436c962f6df Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 15:45:47 +0000
Subject: [PATCH 12/30] Remove TODO

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 34ef1cb19c0d9..5d2d545cf9234 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1218,7 +1218,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
 //===----------------------------------------------------------------------===//
 // Scaling ExtFOp
 //===----------------------------------------------------------------------===//
-// TODO Remove rouding mode attr for EXtf
 def Arith_ScalingExtFOp
     : Arith_Op<
           "scaling_extf", [Pure, SameInputOutputTensorDims,

>From b5df10055f2c6429964dc57283f5c98ad6eec2c8 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Thu, 29 May 2025 15:01:24 -0400
Subject: [PATCH 13/30] Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 5d2d545cf9234..74498314ea88e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1341,7 +1341,7 @@ def Arith_ScalingTruncFOp
     assuming quantization axis is the last axis. 
     Original scales values therefore would be of shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`. 
     `arith.scaling_truncf` operation is an elementwise operation. Therefore, before calling into `arith.scaling_truncf`, if `blockSize != 1` then 
-    scales must be broadcasted appropariately to make it of same shape as the input operand.
+    scales must be broadcast appropriately to ensure they are of the same shape as the input operand.
     Internally arith.scaling_truncf does the following:
     ```
     scaleETy = get_type(scale)

>From b6589aed2776cd347d40d19984a29dbd117b09ee Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Thu, 29 May 2025 15:01:33 -0400
Subject: [PATCH 14/30] Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 74498314ea88e..18b23fa0eec48 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1246,7 +1246,7 @@ def Arith_ScalingExtFOp
     input.extf = arith.extf(input) : inputTy to resultTy
     result = arith.mulf(scale.extf, input.extf)
     ```
-    It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN.
+    It propagates NaN values. Therefore, if either the scale or the input element is a NaN, then the output element value will also be a NaN.
   }];
   let hasVerifier = 1;
   let assemblyFormat =

>From 12c52a6336156e2874059fd12c47ef3584546879 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Thu, 29 May 2025 15:01:45 -0400
Subject: [PATCH 15/30] Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8ad78cd74d8ca..7406ecd36001c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -449,8 +449,8 @@ struct ScalingTruncFOpConverter
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const final {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto inputOperand = op.getIn();
-    auto scaleOperand = op.getScale();
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
     if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");

>From b3cadf29d692a27a5eb62bea576647e90cd81ab0 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Thu, 29 May 2025 15:01:52 -0400
Subject: [PATCH 16/30] Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7406ecd36001c..17376e7bff91e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -426,8 +426,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const final {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto inputOperand = op.getIn();
-    auto scaleOperand = op.getScale();
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
     if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
       return rewriter.notifyMatchFailure(
           op, "scaling extf is not using scale operand of type f8E8M0FNU");

>From fc907803c42f024801a89edfc825ceb65fef9416 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 29 May 2025 19:37:54 +0000
Subject: [PATCH 17/30] Allow implicit truncf to f8E8M0FNU type to extract
 exponent bits

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 21 ++++++++---
 mlir/test/Dialect/Arith/expand-ops.mlir       | 35 +++++++++++++++----
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 17376e7bff91e..dc780bcf896bd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -428,7 +428,13 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     Value inputOperand = op.getIn();
     Value scaleOperand = op.getScale();
-    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from floats 16 bits wide or wider
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
           op, "scaling extf is not using scale operand of type f8E8M0FNU");
     }
@@ -451,11 +457,18 @@ struct ScalingTruncFOpConverter
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     Value inputOperand = op.getIn();
     Value scaleOperand = op.getScale();
-    if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from floats 16 bits wide or wider
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+      scaleTy = scaleOperand.getType();
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
           op, "scaling truncf is not using scale operand of type f8E8M0FNU");
     }
-    auto scaleTy = scaleOperand.getType();
 
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(op.getOut());
@@ -487,7 +500,7 @@ struct ScalingTruncFOpConverter
     Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
     Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
     Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
-    Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
+    Value scaleI32 = b.create<arith::ExtUIOp>(i32Ty, scaleI8);
     Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
     Value normalizedUnbiasedScale =
         b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 87f243d3cbdb4..60856f9d26b33 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -320,7 +320,7 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 // SCHECK: %[[C127:.+]] = arith.constant 127 : i32
 // SCHECK: %[[CN127:.+]] = arith.constant -127 : i32
 // SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : f8E8M0FNU to i8
-// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : i8 to i32
+// SCHECK: %[[EXTI:.+]] = arith.extui %[[BITCASTI8]] : i8 to i32
 // SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] : i32
 // SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : i32
 // SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : i32
@@ -354,7 +354,7 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
 // SCHECK: %[[C127:.+]] = arith.constant dense<127> : vector<4xi32>
 // SCHECK: %[[CN127:.+]] = arith.constant dense<-127> : vector<4xi32>
 // SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : vector<4xf8E8M0FNU> to vector<4xi8>
-// SCHECK: %[[EXTI:.+]] = arith.extsi %[[BITCASTI8]] : vector<4xi8> to vector<4xi32>
+// SCHECK: %[[EXTI:.+]] = arith.extui %[[BITCASTI8]] : vector<4xi8> to vector<4xi32>
 // SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] :  vector<4xi32>
 // SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : vector<4xi32>
 // SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : vector<4xi32>
@@ -385,11 +385,20 @@ func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1:
 // SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
 // SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
 
+// -----
+func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// SCHECK: return
+
 // -----
 
-func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
     // expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
-    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
     return %0 : f4E2M1FN
 }
 
@@ -478,9 +487,23 @@ func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
 
 // -----
 
-func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
-    // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
     %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+    return %0 : f32 
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
+    // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
     return %0 : f32
 }
 

>From 8f91e284ad46c2e20d06c958f02c53bf3a00e10a Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 17:05:11 +0000
Subject: [PATCH 18/30] Use floating point to normalize scales

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 62 +++++++++--------
 mlir/test/Dialect/Arith/expand-ops.mlir       | 66 ++++++++-----------
 2 files changed, 59 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dc780bcf896bd..5227a5b4279f5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -479,7 +479,6 @@ struct ScalingTruncFOpConverter
     Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
     Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
-    Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());
 
     if (inputETy.getIntOrFloatBitWidth() < 32) {
       inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
@@ -489,49 +488,48 @@ struct ScalingTruncFOpConverter
     inputTy = inputOperand.getType();
     inputETy = getElementTypeOrSelf(inputOperand);
 
-    // normalize scale by exponent of the max normal value in result type as per
-    // the OCP MXFP spec
+    // normalize scale by exponent of the max normal value (emax) in result type
+    // as per the OCP MXFP spec
     // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
+    // here this normalization is carried in f32. Therefore instead of
+    // subtraction it does the DivFOp
     const llvm::fltSemantics &resultFltSemantics =
         llvm::cast<FloatType>(resultETy).getFloatSemantics();
     int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
-    Value cMaxNormalExponent =
-        createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
-    Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
-    Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
-    Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
-    Value scaleI32 = b.create<arith::ExtUIOp>(i32Ty, scaleI8);
-    Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
-    Value normalizedUnbiasedScale =
-        b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
-    // clamp scale exponent as per spec
+    Value cEmax = createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
+    Value c1 = createConst(op->getLoc(), i32Ty, 1, rewriter);
+    Value cPow2 = b.create<arith::ShLIOp>(c1, cEmax);
+    Value cPow2F32 = b.create<arith::SIToFPOp>(f32Ty, cPow2);
+    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand);
+    // note that spec also does the clamping but it should only be done for
+    // underflows because diving by 2^emax will only make it smaller.
     // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
-    // upper clamp limit of 127 will be mapped to biased value of 255 and will
-    // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
-    // using arith.extf
-    Value clampUpperCond = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
-    Value clampLowerCond = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
+    Value scaleNormalizedF32 = b.create<arith::DivFOp>(scaleF32, cPow2F32);
+    // If it has underflown then scale will be a denorm FP32 number after
+    // division. Clamp underflows to 2^-127 as per the spec implementation
+    Value scaleNormalizedExponentF8 =
+        b.create<arith::TruncFOp>(scaleTy, scaleNormalizedF32);
+    Value scaleNormalizedExponentU8 =
+        b.create<arith::BitcastOp>(i8Ty, scaleNormalizedExponentF8);
+    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value scaleClampCond = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, cI8Zero, scaleNormalizedExponentU8);
+    // 5.8e-39 is 2^-127, it is a denorm value in f32
+    float clampValue = 5.87747e-39;
+    Value scaleClampValue =
+        createFloatConst(op.getLoc(), f32Ty, clampValue, rewriter);
     Value clampedScale = b.create<arith::SelectOp>(
-        clampUpperCond, c127,
-        b.create<arith::SelectOp>(clampLowerCond, cNeg127,
-                                  normalizedUnbiasedScale));
-    Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
-    Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
-    Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
-    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
+        scaleClampCond, scaleClampValue, scaleNormalizedF32);
     // flush denorms by checking if exponent part of input operand is zero
     // or not.
     Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
     Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
-    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
-    Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
-                                            inputExponentU8);
+    Value inputFlushCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                   cI8Zero, inputExponentU8);
     Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
     Value flushedInput =
-        b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
-    Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
+        b.create<arith::SelectOp>(inputFlushCond, inputTyZero, inputOperand);
+    Value result = b.create<arith::DivFOp>(flushedInput, clampedScale);
     // propagate rounding mode and fast math attributes
     Value resultCast = b.create<arith::TruncFOp>(
         resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 60856f9d26b33..cd3ddc9760644 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -254,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
     %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
     return %0 : f8E8M0FNU
 }
-// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
@@ -268,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
     %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
     return %0 : f8E8M0FNU
 }
-// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
 // CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
@@ -317,27 +317,23 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 
 // SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
 // SCHECK: %[[C2:.+]] = arith.constant 2 : i32
-// SCHECK: %[[C127:.+]] = arith.constant 127 : i32
-// SCHECK: %[[CN127:.+]] = arith.constant -127 : i32
-// SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : f8E8M0FNU to i8
-// SCHECK: %[[EXTI:.+]] = arith.extui %[[BITCASTI8]] : i8 to i32
-// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] : i32
-// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : i32
-// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : i32
-// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : i32
-// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : i32
-// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : i32
-// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : i32
-// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : i32 to i8
-// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : i8 to f8E8M0FNU
-// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : f8E8M0FNU to f32
+// SCHECK: %[[C1:.+]] = arith.constant 1 : i32
+// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : i32
+// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : i32 to f32
+// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : f32
+// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : f32 to f8E8M0FNU
+// SCHECK: %[[SCALEDIVI8:.+]] =  arith.bitcast %[[SCALEDIVF8]] : f8E8M0FNU to i8
+// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
+// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : i8
+// SCHECK: %[[CLAMPVAL:.+]] = arith.constant 5.877470e-39 : f32
+// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : f32 
 // SCHECK: %[[INPUTEXP:.+]] = arith.truncf %arg0 : f32 to f8E8M0FNU
 // SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : f8E8M0FNU to i8
-// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
 // SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : i8
 // SCHECK: %[[CF0:.+]] = arith.constant 0.000000e+00 : f32
 // SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %arg0 : f32
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : f32
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
 // SCHECK: return %[[RESULT]]
 
@@ -351,27 +347,23 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
 // SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
 // SCHECK: %[[INPUTF32:.+]] = arith.extf %arg0 : vector<4xf16> to vector<4xf32>
 // SCHECK: %[[C2:.+]] = arith.constant dense<4> : vector<4xi32>
-// SCHECK: %[[C127:.+]] = arith.constant dense<127> : vector<4xi32>
-// SCHECK: %[[CN127:.+]] = arith.constant dense<-127> : vector<4xi32>
-// SCHECK: %[[BITCASTI8:.+]] =  arith.bitcast %arg1 : vector<4xf8E8M0FNU> to vector<4xi8>
-// SCHECK: %[[EXTI:.+]] = arith.extui %[[BITCASTI8]] : vector<4xi8> to vector<4xi32>
-// SCHECK: %[[UNBIASEDSCALE:.+]] = arith.subi %[[EXTI]], %[[C127]] :  vector<4xi32>
-// SCHECK: %[[NORMALIZED:.+]] = arith.subi %[[UNBIASEDSCALE]], %[[C2]] : vector<4xi32>
-// SCHECK: %[[UPPERCOND:.+]] = arith.cmpi sgt, %[[NORMALIZED]], %[[C127]] : vector<4xi32>
-// SCHECK: %[[LOWERCOND:.+]] = arith.cmpi slt, %[[NORMALIZED]], %[[CN127]] : vector<4xi32>
-// SCHECK: %[[LOWERSELECT:.+]] = arith.select %[[LOWERCOND]], %[[CN127]], %[[NORMALIZED]] : vector<4xi1>, vector<4xi32>
-// SCHECK: %[[UPPERSELECT:.+]] = arith.select %[[UPPERCOND]], %[[C127]], %[[LOWERSELECT]] : vector<4xi1>, vector<4xi32>
-// SCHECK: %[[BIASED:.+]] = arith.addi %[[UPPERSELECT]], %[[C127]] : vector<4xi32>
-// SCHECK: %[[BIASEDI8:.+]] = arith.trunci %[[BIASED]] : vector<4xi32> to vector<4xi8>
-// SCHECK: %[[BITCASTF8:.+]] = arith.bitcast %[[BIASEDI8]] : vector<4xi8> to vector<4xf8E8M0FNU>
-// SCHECK: %[[EXPSCALE:.+]] = arith.extf %[[BITCASTF8]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[C1:.+]] = arith.constant dense<1> : vector<4xi32>
+// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : vector<4xi32>
+// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : vector<4xi32> to vector<4xf32>
+// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : vector<4xf32>
+// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : vector<4xf32> to vector<4xf8E8M0FNU>
+// SCHECK: %[[SCALEDIVI8:.+]] =  arith.bitcast %[[SCALEDIVF8]] : vector<4xf8E8M0FNU> to vector<4xi8>
+// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
+// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : vector<4xi8>
+// SCHECK: %[[CLAMPVAL:.+]] = arith.constant dense<5.877470e-39> : vector<4xf32>
+// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : vector<4xi1>, vector<4xf32>
 // SCHECK: %[[INPUTEXP:.+]] = arith.truncf %[[INPUTF32]] : vector<4xf32> to vector<4xf8E8M0FNU>
 // SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : vector<4xf8E8M0FNU> to vector<4xi8> 
-// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
 // SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : vector<4xi8>
 // SCHECK: %[[CF0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
 // SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %[[INPUTF32]] : vector<4xi1>, vector<4xf32>
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[EXPSCALE]] : vector<4xf32>
+// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : vector<4xf32>
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf32> to vector<4xf6E3M2FN>
 // SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
 
@@ -381,7 +373,7 @@ func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1:
     %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
     return %0 : vector<4xf6E3M2FN>
 }
-// SCHECK-LABLE: @scaling_truncf_propagate_rounding_mode
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
 // SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
 // SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
 
@@ -390,7 +382,7 @@ func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f
     %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
     return %0 : f4E2M1FN
 }
-// SCHECK-LABLE: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
 // SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
 // SCHECK: return
 
@@ -428,7 +420,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
     return %0 : f16
 }
 
-// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
 // CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
 // CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32

>From dc7b67f633b24972e6ac1b1b1d2f9c45cc29586f Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 18:18:00 +0000
Subject: [PATCH 19/30] Rewrite description

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 49 +++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 18b23fa0eec48..7cdec77e6005a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,19 +1229,23 @@ def Arith_ScalingExtFOp
   let summary =
       "cast from floating-point to larger floating-point using provided scales";
   let description = [{
-    Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. 
-    Scale operand is expected to be of type f8E8M0. But that can be relaxed in future.  
-    Scale is usually calculated per block.  
-    Let's say originally input is shape <dim1 x dim2 x dim3 .. x dimN> then, given blockSize it can be reshaped to <dim1 x dim2 x ... (dimN/blockSize) x blockSize>. 
-    Scales will be calculated on the block axis. Therefore scale will be of shape <dim1 x dim2 x dim3 ... (dimN/blockSize) x 1>. 
-    Before calling into `arith.scaling_extf`, scales must be broadcasted appropariately to make it as same shape as input making `arith.scaling_extf` an elemenwise op.  
-    In above example. scales should be broadcasted to shape of <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>.
+    This operations upcasts quantized floating point value using provided scales values. 
+    It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
+    Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
+    If Scales are calculated per block where blockSize != 1 then, Scales may require broadcasting to make this operation elementwise. 
+    For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
+    then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
+    scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.
+    Scales could also be of some other shape as long as it is broadcast compatible with input e.g. <1x1x...(dimN/blockSize)x1>. 
+    In this example, before calling into arith.scaling_extf, scales must be broadcasted to <dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>. 
+    Note that there could be multiple quantization axes. 
+    Internally, `arith.scaling_extf` would do following:
     ```
     resultTy = get_type(result) 
     scaleTy  = get_type(scale)
     inputTy = get_type(input)
+    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
     scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
-    scale.bcast = broadcast_to_same_shape_as(result)
     scale.extf = arith.extf(sale.bcast) : f8E8M0 to resultTy
     input.extf = arith.extf(input) : inputTy to resultTy
     result = arith.mulf(scale.extf, input.extf)
@@ -1335,26 +1339,31 @@ def Arith_ScalingTruncFOp
   let summary =
       "cast from floating-point to narrower floating-point with scales";
   let description = [{
-    This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values.
-    This quantization usually happens over a block of values. All values in that block share same scale value for quantization purposes. 
-    Therefore original input of shape `<dim1 x dim2 ... dimN>` can be thought of as of shape `<dim1 x dim2 x ... (dimN / blockSize) x blockSize>`, 
-    assuming quantization axis is the last axis. 
-    Original scales values therefore would be of shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`. 
-    `arith.scaling_truncf` operation is an elementwise operation. Therefore, before calling into `arith.scaling_truncf`, if `blockSize != 1` then 
-    Scales must be broadcast appropriately to ensure they are of the same shape as the input operand.
-    Internally arith.scaling_truncf does the following:
+    This operation quantizes input using provided scales value. 
+    It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
+    Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
+    If Scales are calculated per block where blockSize != 1 then, Scales may require broadcasting to make this operation elementwise. 
+    For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
+    then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
+    scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.
+    Scales could also be of some other shape as long as it is broadcast compatible with input e.g. <1x1x...(dimN/blockSize)x1>. 
+    In this example, before calling into arith.scaling_truncf, scales must be broadcasted to <dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>. 
+    Note that there could be multiple quantization axes. 
+    Internally, `arith.scaling_truncf` would do following:
     ```
     scaleETy = get_type(scale)
     inputETy = get_type(input)
     resultETy = get_type(result)
-    scale.bcast = broadcast_to_same_shape_as(input)
+    // prepare Scale values with normalization and clamping
     scale.exponent = arith.truncf(scale.bcst) : scaleETy to f8E8M0
     scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
-    result = arith.divf(input, scale.extf)
+    scale.normalize = arith.divf(scale.extf, emax)  // emax is calculated as exponent of the largest normal value in quantized type. 
+    scale.clamped = clamp(scale.normalize) // clamp underflows
+    input.flused = flush_denorms(input)
+    result = arith.divf(input.flushed, scale.clamped)
     result.cast = arith.truncf(result, resultETy)
     ```
-    OCP MXFP spec flushes denorm input value before quantization. NaNs are propagated. 
-
+    Flushing of denorms in input and scale normalization with emax is added as per the OCP MXFP spec. 
   }];
   let hasVerifier = 1;
   let assemblyFormat =

>From 109ddc57f1eb0c2dd1f947371bc5d4033ece21eb Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 18:23:36 +0000
Subject: [PATCH 20/30] change error message

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 5227a5b4279f5..9375f7d546fb2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -436,7 +436,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
-          op, "scaling extf is not using scale operand of type f8E8M0FNU");
+          op, "scaling_extf is using scales of type which can not be converted "
+              "to f8E8M0FNU");
     }
     Type resultTy = op.getType();
     // extf on scale will essentially create f32 number that is 2^scale and will
@@ -467,7 +468,8 @@ struct ScalingTruncFOpConverter
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
-          op, "scaling truncf is not using scale operand of type f8E8M0FNU");
+          op, "scaling_truncf is using scales type which can not be converted "
+              "to f8E8M0FNU");
     }
 
     Type resultTy = op.getType();

>From f3d9865850f77a2b8c60548954070497d2820047 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 18:23:43 +0000
Subject: [PATCH 21/30] some nits

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 7cdec77e6005a..5198b3d7ef89a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1232,7 +1232,7 @@ def Arith_ScalingExtFOp
     This operations upcasts quantized floating point value using provided scales values. 
     It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
     Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
-    If Scales are calculated per block where blockSize != 1 then, Scales may require broadcasting to make this operation elementwise. 
+    If scales are calculated per block where blockSize != 1 then, scales may require broadcasting to make this operation elementwise. 
     For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
     then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
     scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.
@@ -1342,7 +1342,7 @@ def Arith_ScalingTruncFOp
     This operation quantizes input using provided scales value. 
     It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
     Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
-    If Scales are calculated per block where blockSize != 1 then, Scales may require broadcasting to make this operation elementwise. 
+    If scales are calculated per block where blockSize != 1 then, scales may require broadcasting to make this operation elementwise. 
     For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
     then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
     scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.

>From 3ccb208dc99c367073ec0ba468f1cf169452d091 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 18:31:12 +0000
Subject: [PATCH 22/30] Formatting

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 77 ++++++++++++-------
 1 file changed, 50 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 5198b3d7ef89a..0dc636ce00aa7 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,17 +1229,26 @@ def Arith_ScalingExtFOp
   let summary =
       "cast from floating-point to larger floating-point using provided scales";
   let description = [{
-    This operations upcasts quantized floating point value using provided scales values. 
-    It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
-    Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
-    If scales are calculated per block where blockSize != 1 then, scales may require broadcasting to make this operation elementwise. 
-    For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
-    then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
-    scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.
-    Scales could also be of some other shape as long as it is broadcast compatible with input e.g. <1x1x...(dimN/blockSize)x1>. 
-    In this example, before calling into arith.scaling_extf, scales must be broadcasted to <dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>. 
-    Note that there could be multiple quantization axes. 
-    Internally, `arith.scaling_extf` would do following:
+  This operation upcasts quantized floating-point values using provided scale 
+  values. It expects both scales and the input operand to be of the same shape, 
+  making the operation elementwise. Scales are usually calculated per block 
+  following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+  If scales are calculated per block where blockSize != 1, then scales may 
+  require broadcasting to make this operation elementwise. For example, let's 
+  say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
+  assuming quantization happens on the last axis, the input can be reshaped to 
+  `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
+  per block on the last axis. Therefore, scales will be of shape 
+  `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
+  shape as long as it is broadcast compatible with the input, e.g., 
+  `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+  In this example, before calling into `arith.scaling_extf`, scales must be 
+  broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
+  that there could be multiple quantization axes. Internally, 
+  `arith.scaling_extf` would perform the following:
+ 
     ```
     resultTy = get_type(result) 
     scaleTy  = get_type(scale)
@@ -1250,11 +1259,13 @@ def Arith_ScalingExtFOp
     input.extf = arith.extf(input) : inputTy to resultTy
     result = arith.mulf(scale.extf, input.extf)
     ```
-    It propagates NaN values. Therefore, if either scale or the input element contains NaN, then the output element value will also be a NaN.
+    It propagates NaN values. Therefore, if either scale or the input element 
+    contains NaN, then the output element value will also be a NaN.
   }];
   let hasVerifier = 1;
   let assemblyFormat =
-      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` 
+      type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1339,17 +1350,26 @@ def Arith_ScalingTruncFOp
   let summary =
       "cast from floating-point to narrower floating-point with scales";
   let description = [{
-    This operation quantizes input using provided scales value. 
-    It expects both scales and input operand to be of same shape and therefore operation is an elementwise operation. 
-    Scales are usually calculated per block following OCP MXFP spec as described in  https://arxiv.org/abs/2310.10537.
-    If scales are calculated per block where blockSize != 1 then, scales may require broadcasting to make this operation elementwise. 
-    For example, let's say input is of shape <dim1 x dim2 x ... dimN> then given blockSize != 1 and assuming quantization happens on last axis, 
-    then input can be reshaped to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated per block on last axis. Therefore 
-    scales will be of shape `<dim1 x dim2 x .. (dimN/blockSize) x 1>`.
-    Scales could also be of some other shape as long as it is broadcast compatible with input e.g. <1x1x...(dimN/blockSize)x1>. 
-    In this example, before calling into arith.scaling_truncf, scales must be broadcasted to <dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>. 
-    Note that there could be multiple quantization axes. 
-    Internally, `arith.scaling_truncf` would do following:
+    This operation quantizes input using the provided scale values. It expects 
+    both scales and the input operand to be of the same shape and, therefore, 
+    makes the operation elementwise. Scales are usually calculated per block 
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+    If scales are calculated per block where blockSize != 1, scales may require 
+    broadcasting to make this operation elementwise. For example, let's say the 
+    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
+    assuming quantization happens on the last axis, the input can be reshaped to 
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
+    per block on the last axis. Therefore, scales will be of shape 
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
+    shape as long as it is broadcast compatible with the input, e.g., 
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_truncf`, scales must be 
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
+    that there could be multiple quantization axes. Internally, 
+    `arith.scaling_truncf` would perform the following:
+ 
     ```
     scaleETy = get_type(scale)
     inputETy = get_type(input)
@@ -1357,17 +1377,20 @@ def Arith_ScalingTruncFOp
     // prepare Scale values with normalization and clamping
     scale.exponent = arith.truncf(scale.bcst) : scaleETy to f8E8M0
     scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
-    scale.normalize = arith.divf(scale.extf, emax)  // emax is calculated as exponent of the largest normal value in quantized type. 
+    // emax is calculated as exponent of the largest normal value in quantized type.
+    scale.normalize = arith.divf(scale.extf, emax)   
     scale.clamped = clamp(scale.normalize) // clamp underflows
     input.flused = flush_denorms(input)
     result = arith.divf(input.flushed, scale.clamped)
     result.cast = arith.truncf(result, resultETy)
     ```
-    Flushing of denorms in input and scale normalization with emax is added as per the OCP MXFP spec. 
+    Flushing of denorms in input and scale normalization with emax is added as per 
+    the OCP MXFP spec. 
   }];
   let hasVerifier = 1;
   let assemblyFormat =
-      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` 
+      type($in) `,` type($scale) `to` type($out)}];
 }
 
 //===----------------------------------------------------------------------===//

>From d1543414578abf95a495b4eb6fe9b6201de8e9f6 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 30 May 2025 18:37:56 +0000
Subject: [PATCH 23/30] Change comment

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 9375f7d546fb2..40015c693ef41 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -440,8 +440,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
               "to f8E8M0FNU");
     }
     Type resultTy = op.getType();
-    // extf on scale will essentially create f32 number that is 2^scale and will
-    // also propagate NaNs
+    // extf on scale will essentially create floating point number
+    // of type resulTy that is 2^scale and will also propagate NaNs
     Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
     Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
     Value result = b.create<arith::MulFOp>(inputExt, scaleExt);

>From d8a76fa94ea0214f7a7643fbb1d9a1dceadb7538 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Sat, 31 May 2025 08:23:25 -0400
Subject: [PATCH 24/30] Update
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index bba335431aee6..e0a4567d6f406 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,7 +62,7 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 
-/// Add patterns to expland scaling ExtF/TruncF ops to equivalent arith ops
+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
 void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
 
 /// Add patterns to expand Arith ops.

>From a0aa490e4decec9139d9e0de2d36f768e25a870c Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav at users.noreply.github.com>
Date: Sat, 31 May 2025 08:23:48 -0400
Subject: [PATCH 25/30] Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 40015c693ef41..6b506f3f6431b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -504,7 +504,7 @@ struct ScalingTruncFOpConverter
     Value cPow2F32 = b.create<arith::SIToFPOp>(f32Ty, cPow2);
     Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand);
     // note that spec also does the clamping but it should only be done for
-    // underflows because diving by 2^emax will only make it smaller.
+    // underflows because dividing by 2^emax will only make it smaller.
     // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
     Value scaleNormalizedF32 = b.create<arith::DivFOp>(scaleF32, cPow2F32);
     // If it has underflown then scale will be a denorm FP32 number after

>From ff66dadd716709887a6e3a47955e1656444df2e5 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 31 May 2025 12:28:03 +0000
Subject: [PATCH 26/30] address some review comments

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td  | 4 ++--
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0dc636ce00aa7..1dcf9c05c709f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,7 +1227,7 @@ def Arith_ScalingExtFOp
           OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
       Results<(outs FloatLike:$out)> {
   let summary =
-      "cast from floating-point to larger floating-point using provided scales";
+      "Upcasts quantized floats using provided scales values following OCP MXFP Spec";
   let description = [{
   This operation upcasts quantized floating-point values using provided scale 
   values. It expects both scales and the input operand to be of the same shape, 
@@ -1348,7 +1348,7 @@ def Arith_ScalingTruncFOp
           OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
       Results<(outs FloatLike:$out)> {
   let summary =
-      "cast from floating-point to narrower floating-point with scales";
+      "Downcasts input floating point values using provided scales values following OCP MXFP Spec";
   let description = [{
     This operation quantizes input using the provided scale values. It expects 
     both scales and the input operand to be of the same shape and, therefore, 
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6b506f3f6431b..fa092fae6b185 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -11,7 +11,6 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/APFloat.h"

>From f7c1b795d1ebe7a8d570690e572dddfa8e39fa70 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Mon, 2 Jun 2025 15:09:42 +0000
Subject: [PATCH 27/30] Fix docs

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 1dcf9c05c709f..c2b764065ab27 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1255,7 +1255,7 @@ def Arith_ScalingExtFOp
     inputTy = get_type(input)
     assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
     scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
-    scale.extf = arith.extf(sale.bcast) : f8E8M0 to resultTy
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
     input.extf = arith.extf(input) : inputTy to resultTy
     result = arith.mulf(scale.extf, input.extf)
     ```
@@ -1375,7 +1375,7 @@ def Arith_ScalingTruncFOp
     inputETy = get_type(input)
     resultETy = get_type(result)
     // prepare Scale values with normalization and clamping
-    scale.exponent = arith.truncf(scale.bcst) : scaleETy to f8E8M0
+    scale.exponent = arith.truncf(scale) : scaleETy to f8E8M0
     scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
     // emax is calculated as exponent of the largest normal value in quantized type.
     scale.normalize = arith.divf(scale.extf, emax)   

>From 45e7dba87ade64aab11e8cb0d44baac7772a700f Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 6 Jun 2025 17:15:42 +0000
Subject: [PATCH 28/30] Simplify arith.scaling_truncf to just do division and
 truncation. Denorm flushing on input should be carried out using specified
 fastMath flag. Scales are assumed to be normalized and clamped.

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  28 +++--
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 103 ++++--------------
 mlir/test/Dialect/Arith/expand-ops.mlir       |  43 +-------
 3 files changed, 42 insertions(+), 132 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c2b764065ab27..5b1fa529465f4 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1354,6 +1354,10 @@ def Arith_ScalingTruncFOp
     both scales and the input operand to be of the same shape and, therefore, 
     makes the operation elementwise. Scales are usually calculated per block 
     following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+    Users are required to normalize and clamp the scales as necessary before calling
+    passing them to this operation.  OCP MXFP spec also does the flushing of denorms
+    on the input operand, which should be handled during lowering by passing appropriate 
+    fastMath flag to this operation. 
 
     If scales are calculated per block where blockSize != 1, scales may require 
     broadcasting to make this operation elementwise. For example, let's say the 
@@ -1369,23 +1373,17 @@ def Arith_ScalingTruncFOp
     broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
     that there could be multiple quantization axes. Internally, 
     `arith.scaling_truncf` would perform the following:
- 
+
     ```
-    scaleETy = get_type(scale)
-    inputETy = get_type(input)
-    resultETy = get_type(result)
-    // prepare Scale values with normalization and clamping
-    scale.exponent = arith.truncf(scale) : scaleETy to f8E8M0
-    scale.extf = arith.extf(scale.exponent)  : f8E8M0 to inputETy
-    // emax is calculated as exponent of the largest normal value in quantized type.
-    scale.normalize = arith.divf(scale.extf, emax)   
-    scale.clamped = clamp(scale.normalize) // clamp underflows
-    input.flused = flush_denorms(input)
-    result = arith.divf(input.flushed, scale.clamped)
-    result.cast = arith.truncf(result, resultETy)
+    scaleTy = get_type(scale)
+    inputTy = get_type(input)
+    resultTy = get_type(result)
+    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultTy)
     ```
-    Flushing of denorms in input and scale normalization with emax is added as per 
-    the OCP MXFP spec. 
   }];
   let hasVerifier = 1;
   let assemblyFormat =
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index fa092fae6b185..96191feae6182 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,8 +13,6 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/APFloat.h"
-#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -25,16 +23,6 @@ namespace arith {
 
 using namespace mlir;
 
-static Value createFloatConst(Location loc, Type type, float value,
-                              PatternRewriter &rewriter) {
-  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
-  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(shapedTy, attr));
-  }
-  return rewriter.create<arith::ConstantOp>(loc, attr);
-}
-
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
@@ -368,7 +356,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
+      result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
+                                         op.getFastmathAttr());
     } else if (resultETy.getIntOrFloatBitWidth() > 32) {
       result = b.create<arith::ExtFOp>(resultTy, result);
     }
@@ -406,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     if (operandETy.getIntOrFloatBitWidth() < 32) {
-      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+      operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
     } else if (operandETy.getIntOrFloatBitWidth() > 32) {
-      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+      operand = b.create<arith::TruncFOp>(
+          f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
     }
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -431,7 +421,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     // allow implicit exponent extraction from 16/32 bits floats
     if (scaleETy.getIntOrFloatBitWidth() >= 16) {
       scaleETy = b.getF8E8M0Type();
-      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
@@ -441,14 +432,22 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     Type resultTy = op.getType();
     // extf on scale will essentially create floating point number
     // of type resulTy that is 2^scale and will also propagate NaNs
-    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
-    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
-    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    Value scaleExt =
+        b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+    Value inputExt =
+        b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+    Value result =
+        b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+/*
+Expands arith.ScalingTruncFOp(in, scale) into
+  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
+  result = arith.truncf(in / (2^scale))
+ */
 struct ScalingTruncFOpConverter
     : public OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -470,68 +469,14 @@ struct ScalingTruncFOpConverter
           op, "scaling_truncf is using scales type which can not be converted "
               "to f8E8M0FNU");
     }
-
     Type resultTy = op.getType();
-    Type resultETy = getElementTypeOrSelf(op.getOut());
-
     Type inputTy = inputOperand.getType();
-    Type inputETy = getElementTypeOrSelf(inputOperand);
-
-    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
-    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
-
-    if (inputETy.getIntOrFloatBitWidth() < 32) {
-      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
-    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
-      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
-    }
-    inputTy = inputOperand.getType();
-    inputETy = getElementTypeOrSelf(inputOperand);
-
-    // normalize scale by exponent of the max normal value (emax) in result type
-    // as per the OCP MXFP spec
-    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
-    // here this normalization is carried in f32. Therefore instead of
-    // subtraction it does the DivFOp
-    const llvm::fltSemantics &resultFltSemantics =
-        llvm::cast<FloatType>(resultETy).getFloatSemantics();
-    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
-    Value cEmax = createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
-    Value c1 = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cPow2 = b.create<arith::ShLIOp>(c1, cEmax);
-    Value cPow2F32 = b.create<arith::SIToFPOp>(f32Ty, cPow2);
-    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand);
-    // note that spec also does the clamping but it should only be done for
-    // underflows because dividing by 2^emax will only make it smaller.
-    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
-    Value scaleNormalizedF32 = b.create<arith::DivFOp>(scaleF32, cPow2F32);
-    // If it has underflown then scale will be a denorm FP32 number after
-    // division. Clamp underflows to 2^-127 as per the spec implementation
-    Value scaleNormalizedExponentF8 =
-        b.create<arith::TruncFOp>(scaleTy, scaleNormalizedF32);
-    Value scaleNormalizedExponentU8 =
-        b.create<arith::BitcastOp>(i8Ty, scaleNormalizedExponentF8);
-    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
-    Value scaleClampCond = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::eq, cI8Zero, scaleNormalizedExponentU8);
-    // 5.8e-39 is 2^-127, it is a denorm value in f32
-    float clampValue = 5.87747e-39;
-    Value scaleClampValue =
-        createFloatConst(op.getLoc(), f32Ty, clampValue, rewriter);
-    Value clampedScale = b.create<arith::SelectOp>(
-        scaleClampCond, scaleClampValue, scaleNormalizedF32);
-    // flush denorms by checking if exponent part of input operand is zero
-    // or not.
-    Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
-    Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
-    Value inputFlushCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
-                                                   cI8Zero, inputExponentU8);
-    Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
-    Value flushedInput =
-        b.create<arith::SelectOp>(inputFlushCond, inputTyZero, inputOperand);
-    Value result = b.create<arith::DivFOp>(flushedInput, clampedScale);
-    // propagate rounding mode and fast math attributes
+    // this will create a floating point number of type
+    // inputTy that is 2^scale and will also propagate NaNs
+    scaleOperand =
+        b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
+                                           op.getFastmathAttr());
     Value resultCast = b.create<arith::TruncFOp>(
         resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
     rewriter.replaceOp(op, resultCast);
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index cd3ddc9760644..be254cb0405fd 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -316,24 +316,8 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 }
 
 // SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
-// SCHECK: %[[C2:.+]] = arith.constant 2 : i32
-// SCHECK: %[[C1:.+]] = arith.constant 1 : i32
-// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : i32
-// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : i32 to f32
 // SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : f32
-// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : f32 to f8E8M0FNU
-// SCHECK: %[[SCALEDIVI8:.+]] =  arith.bitcast %[[SCALEDIVF8]] : f8E8M0FNU to i8
-// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
-// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : i8
-// SCHECK: %[[CLAMPVAL:.+]] = arith.constant 5.877470e-39 : f32
-// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : f32 
-// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %arg0 : f32 to f8E8M0FNU
-// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : f8E8M0FNU to i8
-// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : i8
-// SCHECK: %[[CF0:.+]] = arith.constant 0.000000e+00 : f32
-// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %arg0 : f32
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
 // SCHECK: return %[[RESULT]]
 
@@ -345,26 +329,9 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
 }
 
 // SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
-// SCHECK: %[[INPUTF32:.+]] = arith.extf %arg0 : vector<4xf16> to vector<4xf32>
-// SCHECK: %[[C2:.+]] = arith.constant dense<4> : vector<4xi32>
-// SCHECK: %[[C1:.+]] = arith.constant dense<1> : vector<4xi32>
-// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : vector<4xi32>
-// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : vector<4xi32> to vector<4xf32>
-// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : vector<4xf32>
-// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : vector<4xf32> to vector<4xf8E8M0FNU>
-// SCHECK: %[[SCALEDIVI8:.+]] =  arith.bitcast %[[SCALEDIVF8]] : vector<4xf8E8M0FNU> to vector<4xi8>
-// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
-// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : vector<4xi8>
-// SCHECK: %[[CLAMPVAL:.+]] = arith.constant dense<5.877470e-39> : vector<4xf32>
-// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : vector<4xi1>, vector<4xf32>
-// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %[[INPUTF32]] : vector<4xf32> to vector<4xf8E8M0FNU>
-// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : vector<4xf8E8M0FNU> to vector<4xi8> 
-// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : vector<4xi8>
-// SCHECK: %[[CF0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %[[INPUTF32]] : vector<4xi1>, vector<4xf32>
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
 // SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
 
 // -----
@@ -374,7 +341,7 @@ func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1:
     return %0 : vector<4xf6E3M2FN>
 }
 // SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
-// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf16> to vector<4xf6E3M2FN>
 // SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
 
 // -----

>From 80061d6fbd1d45f13a7ccd698128fb2dcaed034e Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 6 Jun 2025 23:40:15 +0000
Subject: [PATCH 29/30] address review comments and add tests

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  8 ++-
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 11 +++--
 mlir/test/Dialect/Arith/expand-ops.mlir       | 49 +++++++++++++++++--
 3 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 5b1fa529465f4..3ff27967c75b7 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,9 +1227,9 @@ def Arith_ScalingExtFOp
           OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
       Results<(outs FloatLike:$out)> {
   let summary =
-      "Upcasts quantized floats using provided scales values following OCP MXFP Spec";
+      "Upcasts input floats using provided scales values following OCP MXFP Spec";
   let description = [{
-  This operation upcasts quantized floating-point values using provided scale 
+  This operation upcasts input floating-point values using provided scale 
   values. It expects both scales and the input operand to be of the same shape, 
   making the operation elementwise. Scales are usually calculated per block 
   following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
@@ -1253,7 +1253,6 @@ def Arith_ScalingExtFOp
     resultTy = get_type(result) 
     scaleTy  = get_type(scale)
     inputTy = get_type(input)
-    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
     scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
     scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
     input.extf = arith.extf(input) : inputTy to resultTy
@@ -1350,7 +1349,7 @@ def Arith_ScalingTruncFOp
   let summary =
       "Downcasts input floating point values using provided scales values following OCP MXFP Spec";
   let description = [{
-    This operation quantizes input using the provided scale values. It expects 
+    This operation downcasts input using the provided scale values. It expects 
     both scales and the input operand to be of the same shape and, therefore, 
     makes the operation elementwise. Scales are usually calculated per block 
     following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
@@ -1378,7 +1377,6 @@ def Arith_ScalingTruncFOp
     scaleTy = get_type(scale)
     inputTy = get_type(input)
     resultTy = get_type(result)
-    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
     scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
     scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
     result = arith.divf(input, scale.extf)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 96191feae6182..534aff9562b7a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -359,7 +359,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
       result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
                                          op.getFastmathAttr());
     } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
+      result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -417,11 +417,13 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     Value inputOperand = op.getIn();
     Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
     Type scaleETy = getElementTypeOrSelf(scaleOperand);
     // allow implicit exponent extraction from 16/32 bits floats
     if (scaleETy.getIntOrFloatBitWidth() >= 16) {
       scaleETy = b.getF8E8M0Type();
-      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand, nullptr,
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
                                                op.getFastmathAttr());
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
@@ -461,8 +463,9 @@ struct ScalingTruncFOpConverter
     // allow implicit exponent extraction from 16/32 bits floats
     if (scaleETy.getIntOrFloatBitWidth() >= 16) {
       scaleETy = b.getF8E8M0Type();
-      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
-      scaleTy = scaleOperand.getType();
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index be254cb0405fd..db1349feaff3a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -336,15 +336,19 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
 
 // -----
 
-func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
-    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
     return %0 : vector<4xf6E3M2FN>
 }
-// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
-// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf16> to vector<4xf6E3M2FN>
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
+// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
 // SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
 
 // -----
+
 func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
     %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
     return %0 : f4E2M1FN
@@ -353,6 +357,15 @@ func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f
 // SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
 // SCHECK: return
 
+// -----
+func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
+    return %0 : vector<4xf4E2M1FN>
+}
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: return
+
 // -----
 
 func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
@@ -507,6 +520,34 @@ func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector
 
 // -----
 
+func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
 func.func @maxsi(%a: i32, %b: i32) -> i32 {
   %result = arith.maxsi %a, %b : i32
   return %result : i32

>From a38ac5ea3b3cd71f1e4692fb8c74f1d437f02c03 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Fri, 6 Jun 2025 23:40:59 +0000
Subject: [PATCH 30/30] Formatting

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3ff27967c75b7..adc27ae6bdafb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1226,8 +1226,8 @@ def Arith_ScalingExtFOp
       Arguments<(ins FloatLike:$in, FloatLike:$scale,
           OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
       Results<(outs FloatLike:$out)> {
-  let summary =
-      "Upcasts input floats using provided scales values following OCP MXFP Spec";
+  let summary = "Upcasts input floats using provided scales values following "
+                "OCP MXFP Spec";
   let description = [{
   This operation upcasts input floating-point values using provided scale 
   values. It expects both scales and the input operand to be of the same shape, 
@@ -1346,8 +1346,8 @@ def Arith_ScalingTruncFOp
           OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
           OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
       Results<(outs FloatLike:$out)> {
-  let summary =
-      "Downcasts input floating point values using provided scales values following OCP MXFP Spec";
+  let summary = "Downcasts input floating point values using provided scales "
+                "values following OCP MXFP Spec";
   let description = [{
     This operation downcasts input using the provided scale values. It expects 
     both scales and the input operand to be of the same shape and, therefore, 



More information about the Mlir-commits mailing list