[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 00:13:26 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/144157

>From 6ef31a776b020ea5fb34ff6882aba7fe67ff421b Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 9 Jun 2025 21:06:41 +0000
Subject: [PATCH 1/8] initial implementation to fix

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 206 ++++++++++++++++++
 1 file changed, 206 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..87483ced8e5cf 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+    
+    // create constants to extract mantissa / exponent
+    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+    // create constants for NaNs
+    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     return success();
   }
 };
+/*
+Conversion from F32 to F4E2M1 according to the OCP Spec:
+www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+The spec requiers us to perform Round to Nearest, Ties to Even.
+
+This means that after rounding, we should break ties by choosing the option
+which results in a mantissa of 0 in the least significant digit.
+
+Table of representable values in F4E2M1:
+
+Note: x is sign bit
+| Binary | Value ( + / - ) 
+| x000   | 0.0
+| x001   | 0.5
+| x010   | 1.0
+| x011   | 1.5
+| x100   | 2.0
+| x101   | 3.0
+| x110   | 4.0
+| x111   | 6.0
+
+Conversion procedure: 
+Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
+Create bias adjusted exponent, E_1 <- E_0 - 126
+If E_0 <= 0111 1110
+    M_1 <- 0, E_1 <- 00
+    end
+if E_1 == 00 (special case for almost subnormal)
+    if we must round up (M_0 >= 10000000000000000000000)
+        M_1 <- 0 
+        E_1 <- 01
+    else
+        M_1 <- 1
+    end
+Else if E_1 > 00
+    roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
+    if roundToEven
+        M_1 <- 0 
+    else
+        M_1 <- 1
+    If M_0 >= 11000000000000000000000
+        increment E_1
+    If E_1 > 11 (saturate if beyond range)
+        M_1 <- 1, E_1 <- 11
+end
+*/
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Constants
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+    Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+    Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
+    Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+    Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+  
+    Value cF32MantissaWidth = c0x00000017; // 23
+    Value cF4MantissaWidth = c0x1; // 1
+    Value cF32SignExpWidth = c0x00000009; // 9
+    Value cF32MantissaMask = c0x007fffff;
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    
+    Value cSubnormalExp = c0x7e; // 126
+
+    // Regular case
+    Value biasAdjustment = c0x7e; // 126
+    Value cRoundUp = c0x00600000; // 110 0000...
+    Value cRoundDown = c0x00200000; // 010 0000...
+    Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
+    Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
+    // If we round up or down to even, set mantissa to 0
+    Value shouldRoundUp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
+    Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
+                                                    man23Bits, cRoundDown);
+    // dont need to worry about saturation this way
+    f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
+    Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+    Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
+    f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
+
+    // Bordering subnormal
+    Value cSubnormalRoundUp =
+        createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
+    Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
+    Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
+    Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
+                                                  man23Bits, cSubnormalRoundUp);
+    f4EdgeRounded =
+        b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
+    Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
+                                           cSubnormalExp);
+
+    // Subnormal
+    Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
+    Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
+                                          cSubnormalExp);
+
+    // create constants to extract mantissa / exponent
+    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
+    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
+
+    // create constants for NaNs
+    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
+    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
 
 /*
 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type

>From abc108e9f40adfcb1ac5b30680626ea5f1a35d33 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 12 Jun 2025 16:26:50 +0000
Subject: [PATCH 2/8] intermediate commit

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 79 ++++++++++++++++---
 1 file changed, 66 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 87483ced8e5cf..0fb802b82fffb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Create an float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+                         PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
   if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
 | x111   | 6.0
 
 Conversion procedure: 
+
+Step 1: Clamp to max f4 value
+
+Step 2: convert exponent, if signed int comparison <= 0, set 0
+
+Step 3: if mantissa[1:] greater than 1000000, add 1
+
 Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
 Create bias adjusted exponent, E_1 <- E_0 - 126
 If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
     Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
+    Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+    
     Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
     Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
     Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
+    Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
     Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-    
-    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
-  
     Value cF32MantissaWidth = c0x00000017; // 23
-    Value cF4MantissaWidth = c0x1; // 1
-    Value cF32SignExpWidth = c0x00000009; // 9
+    Value cF4MantissaWidth = c0x1;         // 1
+    Value cF32SignExpWidth = c0x00000009;  // 9
+    Value cF32FirstBitMask = c0x00400000;
+    Value cF32Last22BitMask = c0x003fffff;
     Value cF32MantissaMask = c0x007fffff;
+
+    // Step 1: Clamp to bounds.
+    Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+    Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+  
+    // Step 2: Convert exponent by adjusting bias.
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
-    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
-    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value biasAdjustment = c0x0000007e; // 126
+    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
     
-    Value cSubnormalExp = c0x7e; // 126
+    // Step 0: Special consideration for conversion to 0.5.
+    Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
+    Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
+    Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+
+    // Step 3: Set mantissa to first bit.
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+
+    // Step 4: Round up if necessary.
+    Value cRound = c0x00200000; // 010 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value shouldRound =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 
     // Regular case
-    Value biasAdjustment = c0x7e; // 126
-    Value cRoundUp = c0x00600000; // 110 0000...
-    Value cRoundDown = c0x00200000; // 010 0000...
-    Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
+
+    
+
     Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
     Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
     // If we round up or down to even, set mantissa to 0

>From 63f2337c538f463f55104b4470aa88a07d364433 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 13 Jun 2025 13:06:00 -0700
Subject: [PATCH 3/8] Initial implementation of truncf

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 265 +++++-------------
 1 file changed, 66 insertions(+), 199 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0fb802b82fffb..e35154e8ae32b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
-struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
-      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
-    }
-
-    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
-    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
-    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-    
-    // create constants to extract mantissa / exponent
-    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
-    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
-    // create constants for NaNs
-    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
-    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
-    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
-    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
-    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
-    Value isNan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
-    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     return success();
   }
 };
-/*
-Conversion from F32 to F4E2M1 according to the OCP Spec:
-www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-
-The spec requiers us to perform Round to Nearest, Ties to Even.
-
-This means that after rounding, we should break ties by choosing the option
-which results in a mantissa of 0 in the least significant digit.
-
-Table of representable values in F4E2M1:
-
-Note: x is sign bit
-| Binary | Value ( + / - ) 
-| x000   | 0.0
-| x001   | 0.5
-| x010   | 1.0
-| x011   | 1.5
-| x100   | 2.0
-| x101   | 3.0
-| x110   | 4.0
-| x111   | 6.0
-
-Conversion procedure: 
-
-Step 1: Clamp to max f4 value
-
-Step 2: convert exponent, if signed int comparison <= 0, set 0
-
-Step 3: if mantissa[1:] greater than 1000000, add 1
-
-Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
-Create bias adjusted exponent, E_1 <- E_0 - 126
-If E_0 <= 0111 1110
-    M_1 <- 0, E_1 <- 00
-    end
-if E_1 == 00 (special case for almost subnormal)
-    if we must round up (M_0 >= 10000000000000000000000)
-        M_1 <- 0 
-        E_1 <- 01
-    else
-        M_1 <- 1
-    end
-Else if E_1 > 00
-    roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
-    if roundToEven
-        M_1 <- 0 
-    else
-        M_1 <- 1
-    If M_0 >= 11000000000000000000000
-        increment E_1
-    If E_1 > 11 (saturate if beyond range)
-        M_1 <- 1, E_1 <- 11
-end
-*/
+
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that after rounding, we should break ties by choosing the option
+/// which results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - ) 
+/// | x000   | 0.0
+/// | x001   | 0.5
+/// | x010   | 1.0
+/// | x011   | 1.5
+/// | x100   | 2.0
+/// | x101   | 3.0
+/// | x110   | 4.0
+/// | x111   | 6.0
+///
+/// Conversion procedure: 
+///   Step 1: Clamp to representable bounds.
+///   Step 2: Convert exponent by adjusting bias.
+///   Step 3: Set mantissa to first bit.
+///   Step 4: Special consideration for subnormal and zero exponent.
+///   Step 5: Round up if the low 22 mantissa bits (mantissa[1:]) are >= 0x200000, or if subnormal.
 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
-    Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    
-    Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
-    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
-    Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
-    Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
-    Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
-    Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
-    Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
-    Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
-    Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
-    Value cF32MantissaWidth = c0x00000017; // 23
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
     Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32SignExpWidth = c0x00000009;  // 9
-    Value cF32FirstBitMask = c0x00400000;
-    Value cF32Last22BitMask = c0x003fffff;
-    Value cF32MantissaMask = c0x007fffff;
-
+    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+    
     // Step 1: Clamp to bounds.
     Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
     Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
-    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
-    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
-    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
+    operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
   
     // Step 2: Convert exponent by adjusting bias.
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
-    Value biasAdjustment = c0x0000007e; // 126
+    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
     Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
     
-    // Step 0: Special consideration for conversion to 0.5.
-    Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
-    Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
-    Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value isSubnormal =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-
     // Step 3: Set mantissa to first bit.
-    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
-    Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
+    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+    
+    // Step 4: Special consideration for conversion to 0.5.
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+    Value isZeroExp = 
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    
+    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
 
-    // Step 4: Round up if necessary.
-    Value cRound = c0x00200000; // 010 0000...
+    // Step 5: Round up if necessary.
+    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    shouldRound =
+        b.create<arith::OrIOp>(shouldRound, isSubnormal);
     Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
     f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 
-    // Regular case
-
-    
-
-    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
-    Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
-    // If we round up or down to even, set mantissa to 0
-    Value shouldRoundUp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
-    Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
-                                                    man23Bits, cRoundDown);
-    // dont need to worry about saturation this way
-    f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
-    Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
-    Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
-    f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
-    f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
-
-    // Bordering subnormal
-    Value cSubnormalRoundUp =
-        createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
-    Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
-    Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
-    Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
-                                                  man23Bits, cSubnormalRoundUp);
-    f4EdgeRounded =
-        b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
-    Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
-                                           cSubnormalExp);
-
-    // Subnormal
-    Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
-    Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
-                                          cSubnormalExp);
-
-    // create constants to extract mantissa / exponent
-    Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
-    // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
-
-    // create constants for NaNs
-    Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
-    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
-    Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
-    Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
-    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
-
-    Value isNan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
-    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
-    }
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
     rewriter.replaceOp(op, result);
     return success();
   }

>From 30db57067002595e314a2ebe03b21ca20a87843f Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 10:06:34 -0700
Subject: [PATCH 4/8] PR Review round 1

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 51 ++++++++++---------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e35154e8ae32b..199f7c6d2a34d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,9 +34,9 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
-/// Create an float constant.
-static Value createFloatConst(Location loc, Type type, float value,
-                         PatternRewriter &rewriter) {
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, APFloat value,
+                              PatternRewriter &rewriter) {
   auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
@@ -416,8 +416,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
-      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
     }
 
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -425,58 +425,59 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
-    // Constants
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
-    Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
-    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
-    
+
     // Step 1: Clamp to bounds.
-    Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
-    Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
-    Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
-    operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
+    Value cHigherBound =
+        createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
+    Value cLowerBound =
+        createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
+    Value operandClamped = b.create<arith::MinimumFOp>(cLowerBound, operand);
+    operandClamped = b.create<arith::MaximumFOp>(cHigherBound, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
-  
+
     // Step 2: Convert exponent by adjusting bias.
-    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+    Value cF4MantissaWidth = c0x1;         // 1
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
     
     // Step 3: Set mantissa to first bit.
+    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
     
     // Step 4: Special consideration for conversion to 0.5.
+    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
     Value isSubnormal =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
     Value isNegOneExp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
     Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
     Value isNonZeroMan =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
-    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
-    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
     Value isZeroExp = 
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
     
+    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
     Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
     subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
     f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
-
+    
     // Step 5: Round up if necessary.
+    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
     Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =

>From c4e66f0d0ad1e18c510f3a5194b1a6949ce04fd6 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 12:31:06 -0700
Subject: [PATCH 5/8] adding extf implementation

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 117 ++++++++++++++----
 1 file changed, 94 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 199f7c6d2a34d..889f1f17f0d82 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -334,6 +334,70 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
+        !llvm::isa<Float32Type>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+    Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value cZero =
+        createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
+    Value cHalf =
+        createFloatConst(op->getLoc(), f32Ty, APFloat(0.5f), rewriter);
+
+    Value mantissaBitmask = c0x1;
+    Value exponentBitmask = createConst(op.getLoc(), i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+
+    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
+    Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32Bits = b.create<arith::ShRUIOp>(f32Bits, c0x0000001c);
+
+    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
+    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, c0x00000014);
+    f32Bits = b.create<arith::AddIOp>(f32Bits, f32ExpBits);
+
+    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
+
+    // Special consideration for subnormal exp (exp == 0).
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, biasAdjustment);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
+    Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
+    f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
+
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -389,7 +453,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// Table of representable values in F4E2M1:
 ///
 /// Note: x is sign bit
-/// | Binary | Value ( + / - ) 
+/// | Binary | Value ( + / - )
 /// | x000   | 0.0
 /// | x001   | 0.5
 /// | x010   | 1.0
@@ -399,12 +463,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// | x110   | 4.0
 /// | x111   | 6.0
 ///
-/// Conversion procedure: 
+/// Conversion procedure:
 ///   Step 1: Clamp to representable bounds.
 ///   Step 2: Convert exponent by adjusting bias.
 ///   Step 3: Set mantissa to first bit.
 ///   Step 4: Special consideration for subnormal and zero exponent.
-///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
+///   subnormal.
 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -442,48 +507,54 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
     // Step 2: Convert exponent by adjusting bias.
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    Value cF4MantissaWidth = c0x1;         // 1
-    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value cF4MantissaWidth = c0x1; // 1
+    Value cF32MantissaWidth =
+        createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
-    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value biasAdjustedSignExp =
+        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
     Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
     f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
-    
+
     // Step 3: Set mantissa to first bit.
-    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value cF32FirstBitMask =
+        createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
-    
+
     // Step 4: Special consideration for conversion to 0.5.
-    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value cF32MantissaMask =
+        createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
     Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
     Value isSubnormal =
-    b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
     Value isNegOneExp =
-    b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
     Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
-    Value isNonZeroMan =
-    b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+                                                 man23Bits, c0x00000000);
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
-    Value isZeroExp = 
-    b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-    
+    Value isZeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
     Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
     Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
-    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    Value subResult =
+        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
     subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
     f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
-    
+
     // Step 5: Round up if necessary.
-    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
-    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value cF32Last22BitMask =
+        createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
+    Value cRound =
+        createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
-    shouldRound =
-        b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
     Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
     f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
 

>From 1d8a9864161a66205e7c0a356efcbddb620e1af9 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 16 Jun 2025 22:28:47 -0700
Subject: [PATCH 6/8] add tests and fix various issues revealed by tests

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   3 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |   2 +
 .../Dialect/Arith/Transforms/ExpandOps.cpp    |  77 +++++--
 mlir/test/Dialect/Arith/expand-ops-scale.mlir | 159 ++++++++++++++
 mlir/test/Dialect/Arith/expand-ops.mlir       | 195 ++++--------------
 .../CPU/test-arith-expand-truncf-extf.mlir    |  73 +++++++
 6 files changed, 333 insertions(+), 176 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/expand-ops-scale.mlir
 create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e0a4567d6f406..b03cf2db78041 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
+void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e14b2aeee1c69..c7370b83fdb6c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
               "Enable the BF16 expansion patterns">,
        Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
               "Enable the F8E8M0 expansion patterns">,
+       Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
+              "Enable the F4E2M1 expansion patterns">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 889f1f17f0d82..aef995143112a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -345,9 +345,8 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
-        !llvm::isa<Float32Type>(resultETy)) {
-      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
+    if (!isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
     }
 
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -357,8 +356,9 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
 
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
     Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
     Value cZero =
         createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
     Value cHalf =
@@ -370,29 +370,33 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 
     Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
     Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
-    f32Bits = b.create<arith::ShRUIOp>(f32Bits, c0x0000001c);
+    f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
 
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
     Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
     f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
     Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
     f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
-    f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, c0x00000014);
-    f32Bits = b.create<arith::AddIOp>(f32Bits, f32ExpBits);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
+    f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
 
     Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
     Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
     f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
 
-    // Special consideration for subnormal exp (exp == 0).
+    // Special consideration for subnormal exponent (exp == 00).
     Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                 f32ExpBits, biasAdjustment);
     Value isManSet =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
     Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
-    f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
 
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
+    if (!isa<Float32Type>(resultETy)) {
+      result = b.create<arith::TruncFOp>(resultETy, operand);
+    }
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -481,8 +485,11 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
-      return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
+    if (!isa<Float32Type>(operandETy)) {
+      operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
+    }
+    if (!isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
     }
 
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -491,20 +498,28 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
     Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
     Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
     Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
     Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
 
-    // Step 1: Clamp to bounds.
+    // Step 0: Clamp to bounds.
     Value cHigherBound =
         createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
     Value cLowerBound =
         createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
-    Value operandClamped = b.create<arith::MinimumFOp>(cLowerBound, operand);
-    operandClamped = b.create<arith::MaximumFOp>(cHigherBound, operandClamped);
+    Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
+    operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
 
+    // Step 1: Set sign bit.
+    Value cF32ExpManWidth =
+        createConst(op->getLoc(), i32Ty, 31, rewriter); // 23
+    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
+    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
+    Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
+
     // Step 2: Convert exponent by adjusting bias.
     Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
     Value cF4MantissaWidth = c0x1; // 1
@@ -513,8 +528,9 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustedSignExp =
         b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
-    Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
-    f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
+    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+    f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
 
     // Step 3: Set mantissa to first bit.
     Value cF32FirstBitMask =
@@ -522,7 +538,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
-    Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+    f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
 
     // Step 4: Special consideration for conversion to 0.5.
     Value cF32MantissaMask =
@@ -538,7 +554,6 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
     Value isZeroExp =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-
     Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
     Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
     Value subResult =
@@ -719,16 +734,24 @@ struct ArithExpandOpsPass
     if (includeF8E8M0) {
       arith::populateExpandF8E8M0Patterns(patterns);
     }
+    if (includeF4E2M1) {
+      arith::populateExpandF4E2M1Patterns(patterns);
+    }
 
     target.addDynamicallyLegalOp<arith::ExtFOp>(
       [=](arith::ExtFOp op) {
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) 
+        if (includeBf16) {
           legalTypes &= !(inETy.isBF16() && outETy.isF32());
-        if (includeF8E8M0)
+        } 
+        if (includeF8E8M0) {
           legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+        } 
+        if (includeF4E2M1) {
+          legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
+        }
         return legalTypes;
       });
 
@@ -737,10 +760,15 @@ struct ArithExpandOpsPass
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) 
+        if (includeBf16) {
           legalTypes &= !(inETy.isF32() && outETy.isBF16());
-        if (includeF8E8M0) 
+        }
+        if (includeF8E8M0) {
           legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); 
+        }
+        if (includeF4E2M1) {
+          legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
+        }
         return legalTypes;
       });
 
@@ -765,6 +793,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
+  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
   patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/expand-ops-scale.mlir b/mlir/test/Dialect/Arith/expand-ops-scale.mlir
new file mode 100644
index 0000000000000..0b244e4eed784
--- /dev/null
+++ b/mlir/test/Dialect/Arith/expand-ops-scale.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// CHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// CHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
+// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+
+// CHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// CHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
+// CHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+// CHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
+// CHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
+// CHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
+// CHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+// CHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
+// CHECK: return
+
+// -----
+func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
+    return %0 : vector<4xf4E2M1FN>
+}
+// CHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
+// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: return
+
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+    return %0 : f32 
+}
+
+// CHECK-LABEL: @scaling_extf_to_f32
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+    return %0 : f32 
+}
+
+// CHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
+    // expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
+    return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> 
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f16
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> 
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_bf16
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> 
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
+// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index db1349feaff3a..62e059ccbe8de 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s 
-// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true include-f4e2m1=true" -verify-diagnostics -split-input-file | FileCheck %s 
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -310,64 +309,6 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 
 // -----
 
-func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
-    %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
-    return %0 : f4E2M1FN
-}
-
-// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
-// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
-    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
-    return %0 : vector<4xf6E3M2FN>
-}
-
-// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
-// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
-// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
-
-// -----
-
-func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
-    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
-    return %0 : vector<4xf6E3M2FN>
-}
-// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
-// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
-// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
-// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
-
-// -----
-
-func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
-    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
-    return %0 : f4E2M1FN
-}
-// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
-// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
-// SCHECK: return
-
-// -----
-func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
-    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
-    return %0 : vector<4xf4E2M1FN>
-}
-// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
-// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: return
-
-// -----
-
 func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
     // expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
     %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
@@ -446,33 +387,6 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
 
 // -----
 
-func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
-    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
-    return %0 : f32 
-}
-
-// SCHECK-LABEL: @scaling_extf_to_f32
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
-    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
-    return %0 : f32 
-}
-
-// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
-// SCHECK: return %[[RESULT]]
-
-// -----
-
 func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
     // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
     %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
@@ -481,73 +395,6 @@ func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f
 
 // -----
 
-func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
-    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
-    return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> 
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
-    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
-    return %0 : vector<4xf16>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f16
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> 
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
-    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
-    return %0 : vector<4xbf16>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_bf16
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> 
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
-    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
-    return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
-func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
-    %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
-    return %0 : vector<4xf32>
-}
-
-// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
-// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
-// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
-// SCHECK: return %[[RESULT]]
-
-// -----
-
 func.func @maxsi(%a: i32, %b: i32) -> i32 {
   %result = arith.maxsi %a, %b : i32
   return %result : i32
@@ -593,3 +440,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @truncf_f32_to_f4E2M1FN(%arg0: f32) -> f4E2M1FN {
+  %narrowed = arith.truncf %arg0 : f32 to f4E2M1FN
+  return %narrowed : f4E2M1FN
+}
+
+// CHECK-LABEL: @truncf_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f32_to_f4E2M1FN(%arg0: vector<4xf32>) -> vector<4xf4E2M1FN> {
+  %narrowed = arith.truncf %arg0 : vector<4xf32> to vector<4xf4E2M1FN>
+  return %narrowed : vector<4xf4E2M1FN>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f4E2M1FN
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @extf_f4E2M1FN_to_f32(%arg0: f4E2M1FN) -> f32 {
+  %widened = arith.extf %arg0 : f4E2M1FN to f32
+  return %widened : f32
+}
+
+// CHECK-LABEL: @extf_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f4E2M1FN_to_f32(%arg0: vector<4xf4E2M1FN>) -> vector<4xf32> {
+  %widened = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+  return %widened : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f4E2M1FN_to_f32
+// CHECK-NOT: arith.extf
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
new file mode 100644
index 0000000000000..6e76968c70e5f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -0,0 +1,73 @@
+// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
+
+// RUN: mlir-opt %s --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm \
+// RUN:             --arith-expand="include-f4e2m1=true" \
+// RUN:             --convert-arith-to-llvm -reconcile-unrealized-casts | \
+// RUN:   mlir-runner -e entry --entry-point-result=void \
+// RUN:               --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+func.func @check_extf(%f4 : f4E2M1FN) -> () {  // Prints the f32 extension of %f4.
+  %asF32 = arith.extf %f4 : f4E2M1FN to f32
+  vector.print %asF32 : f32
+  return
+}
+
+// See https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+// for details on F4E2M1 representation 
+func.func @check_truncf(%value : f32) -> () {  // Prints the raw f4 bits as an integer.
+  %f4 = arith.truncf %value : f32 to f4E2M1FN
+  %bits = arith.bitcast %f4 : f4E2M1FN to i4
+  %widened = arith.extui %bits : i4 to i64  // zero-extend so 4-bit codes print as 0..15
+  vector.print %widened : i64
+  return
+}
+
+func.func @entry() {  // Feeds edge-case f32 inputs to check_truncf, then round-trips some through check_extf.
+  %zero = arith.constant 0.0 : f32
+  %half = arith.constant 0.5 : f32
+  %one = arith.constant 1.0 : f32
+  %max = arith.constant 6.0 : f32
+  %min = arith.constant -6.0 : f32
+  %lowerThanMin = arith.constant -1000000.0 : f32
+  %higherThanMax = arith.constant 1000000.0 : f32
+  %mustRound = arith.constant -3.14 : f32  // expected below to become -3.0 (bits 1101 = 13)
+  %nan = arith.constant 0x7f80000 : f32  // NOTE(review): only 7 hex digits (0x07F80000), not an f32 NaN pattern such as 0x7FC00000 - confirm intent
+
+  // CHECK: 0
+  func.call @check_truncf(%zero) : (f32) -> ()
+  // CHECK: 1
+  func.call @check_truncf(%half) : (f32) -> ()
+  // CHECK: 2
+  func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%min) : (f32) -> ()
+  // CHECK: 7
+  func.call @check_truncf(%higherThanMax) : (f32) -> ()
+  // CHECK: 15
+  func.call @check_truncf(%lowerThanMin) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%mustRound) : (f32) -> ()
+  // CHECK: 0
+  func.call @check_truncf(%nan) : (f32) -> ()
+
+  // CHECK: 0
+  %zeroF4 = arith.truncf %zero : f32 to f4E2M1FN
+  func.call @check_extf(%zeroF4) : (f4E2M1FN) -> ()
+  // CHECK: 0.5
+  %halfF4 = arith.truncf %half : f32 to f4E2M1FN
+  func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 6
+  %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
+  func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -6
+  %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
+  func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %mustRoundF4 = arith.truncf %mustRound : f32 to f4E2M1FN
+  func.call @check_extf(%mustRoundF4) : (f4E2M1FN) -> ()
+  return
+}

>From 19394a89cfdbe082bfbf2a45555fa44c5a348fb3 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 18 Jun 2025 19:28:02 +0000
Subject: [PATCH 7/8] Adding lookup implementation for arith.extf + formatting
 fixes

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 126 +++++++++++++-----
 .../CPU/test-arith-expand-truncf-extf.mlir    |   6 +-
 2 files changed, 93 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index aef995143112a..aec2001d64443 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 namespace arith {
@@ -240,9 +241,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!operandETy.isBF16() || !resultETy.isF32()) {
+    if (!operandETy.isBF16() || !resultETy.isF32())
       return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
-    }
 
     Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -270,9 +270,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!operandETy.isF32() || !resultETy.isBF16()) {
+    if (!operandETy.isF32() || !resultETy.isBF16())
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
-    }
 
     if (op.getRoundingmodeAttr()) {
       return rewriter.notifyMatchFailure(
@@ -336,6 +335,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
 
 struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const final {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -402,6 +403,71 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   }
 };
 
+/// Expands a scalar `arith.extf` whose operand is f4E2M1FN by looking up the
+/// f32 bit pattern of the magnitude in an 8-entry constant table and OR-ing
+/// in the sign bit. Shaped operands are rejected; the default benefit is 2.
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (isa<ShapedType>(operandTy))
+      return failure();
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    // f32 bit patterns of the eight magnitudes, indexed by the low three bits.
+    SmallVector<int> values = {
+        0x00000000, // 0.0
+        0x3f000000, // 0.5
+        0x3f800000, // 1.0
+        0x3fc00000, // 1.5
+        0x40000000, // 2.0
+        0x40400000, // 3.0
+        0x40800000, // 4.0
+        0x40c00000  // 6.0
+    };
+    VectorType type = VectorType::get({8}, b.getI32Type());
+    SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
+        values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
+    Value lookupTable = b.create<arith::ConstantOp>(
+        DenseIntElementsAttr::get(type, lookupTableAttr));
+
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
+
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+    // Table index: the low three bits of the i4 (sign bit masked off).
+    Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
+    Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
+    Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
+    Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
+    // Move the sign (bit 3 of the i4) up to bit 31 of the i32: 3 + 28 = 31.
+    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+    Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
+    Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
+    signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
+    // Combine the looked-up magnitude bits with the sign and view as f32.
+    Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
+    Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    // Narrow the f32 value computed above (not the original f4 operand).
+    if (!isa<Float32Type>(resultETy))
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -413,9 +479,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy))
       return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
-    }
 
     Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -485,12 +550,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!isa<Float32Type>(operandETy)) {
+    if (!isa<Float32Type>(operandETy))
       operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
-    }
-    if (!isa<Float4E2M1FNType>(resultETy)) {
+    if (!isa<Float4E2M1FNType>(resultETy))
       return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
-    }
 
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
     Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
@@ -509,8 +572,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
         createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
     Value cLowerBound =
         createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
-    Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
-    operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
+    Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
+    operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
 
     // Step 1: Set sign bit.
@@ -594,14 +657,12 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(resultTy);
-    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy))
       return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
-    }
 
-    if (op.getRoundingmodeAttr()) {
+    if (op.getRoundingmodeAttr())
       return rewriter.notifyMatchFailure(
           op, "only applicable to default rounding mode.");
-    }
 
     Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -711,6 +772,8 @@ struct ArithExpandOpsPass
     arith::populateArithExpandOpsPatterns(patterns);
 
     target.addLegalDialect<arith::ArithDialect>();
+    target.addLegalDialect<vector::VectorDialect>();
+
     // clang-format off
     target.addIllegalOp<
       arith::CeilDivSIOp,
@@ -728,30 +791,24 @@ struct ArithExpandOpsPass
       arith::ScalingTruncFOp
     >();
 
-    if (includeBf16) {
+    if (includeBf16)
       arith::populateExpandBFloat16Patterns(patterns);
-    }
-    if (includeF8E8M0) {
+    if (includeF8E8M0)
       arith::populateExpandF8E8M0Patterns(patterns);
-    }
-    if (includeF4E2M1) {
+    if (includeF4E2M1)
       arith::populateExpandF4E2M1Patterns(patterns);
-    }
 
     target.addDynamicallyLegalOp<arith::ExtFOp>(
       [=](arith::ExtFOp op) {
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) {
+        if (includeBf16)
           legalTypes &= !(inETy.isBF16() && outETy.isF32());
-        } 
-        if (includeF8E8M0) {
+        if (includeF8E8M0)
           legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
-        } 
-        if (includeF4E2M1) {
+        if (includeF4E2M1)
           legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
-        }
         return legalTypes;
       });
 
@@ -760,15 +817,12 @@ struct ArithExpandOpsPass
         Type inETy = getElementTypeOrSelf(op.getOperand().getType());
         Type outETy = getElementTypeOrSelf(op.getType());
         bool legalTypes = true;
-        if (includeBf16) {
+        if (includeBf16)
           legalTypes &= !(inETy.isF32() && outETy.isBF16());
-        }
-        if (includeF8E8M0) {
+        if (includeF8E8M0)
           legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); 
-        }
-        if (includeF4E2M1) {
+        if (includeF4E2M1)
           legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
-        }
         return legalTypes;
       });
 
@@ -794,8 +848,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
 }
 
 void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
-  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
-      patterns.getContext());
+  patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
+               F4E2M1TruncFOpConverter>(patterns.getContext());
 }
 
 void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 6e76968c70e5f..9c310d80d4c2d 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -1,9 +1,9 @@
 // Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
 
-// RUN: mlir-opt %s --convert-vector-to-llvm \
-// RUN:             --convert-func-to-llvm \
+// RUN: mlir-opt %s --convert-func-to-llvm \
 // RUN:             --arith-expand="include-f4e2m1=true" \
-// RUN:             --convert-arith-to-llvm -reconcile-unrealized-casts | \
+// RUN:             --convert-arith-to-llvm --convert-vector-to-llvm \
+// RUN:             --reconcile-unrealized-casts | \
 // RUN:   mlir-runner -e entry --entry-point-result=void \
 // RUN:               --shared-libs=%mlir_c_runner_utils | \
 // RUN:   FileCheck %s --match-full-lines

>From f554b7c22612190d305964ff416f8a277247359a Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 19 Jun 2025 05:03:27 +0000
Subject: [PATCH 8/8] improving extf implementation

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 250 +++++++-----------
 1 file changed, 102 insertions(+), 148 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index aec2001d64443..473980b44b66a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -11,9 +11,11 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -333,133 +335,92 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// In this implementation of extf we take advantage of some key patterns we
+/// notice between the binary representation of an F4E2M1 value and its
+/// corresponding value in fp32.
+///
+/// Note: x is sign bit
+/// | Binary | F4E2M1 | fp32
+/// | x000   | 0.0    | 0000 0000 00
+/// | x001   | 0.5    | 0011 1111 00
+/// | x010   | 1.0    | 0011 1111 10
+/// | x011   | 1.5    | 0011 1111 11
+/// | x100   | 2.0    | 0100 0000 00
+/// | x101   | 3.0    | 0100 0000 01
+/// | x110   | 4.0    | 0100 0000 10
+/// | x111   | 6.0    | 0100 0000 11
+///
+/// 1) There are only two versions of bits [25:31] in the fp32 result
+///     F4E2M1 bits[2:3] decide whether:
+///       - FP32 bits[25:31] = 0011 1111
+///       - FP32 bits[25:31] = 0100 0000
+///     Exception is zero where
+///       - FP32 bits[25:31] = 0000 0000
+///
+/// 2) F4E2M1 bits[1:2] = FP32 bits[23:24]
+///     Exception is 0.5 where
+///       - F4E2M1 bits[1:2] = 01, FP32 bits[23:24] = 00
+///
+/// 3) F4E2M1 bits[4] = FP32 bits[32] (sign bits are equal)
+///
+/// 4) FP32 bits[1:22] = 0
 struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
-  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!isa<Float4E2M1FNType>(operandETy)) {
-      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
-    }
-
-    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
-    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
-
-    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
-
-    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
-    Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
-    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
-    Value cZero =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
-    Value cHalf =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(0.5f), rewriter);
-
-    Value mantissaBitmask = c0x1;
-    Value exponentBitmask = createConst(op.getLoc(), i4Ty, 0x6, rewriter);
-    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
-
-    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
-    Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
-    f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
-
-    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
-    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
-    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
-    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
-    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
-    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
-    f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
-
-    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
-    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
-    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
-    f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
-
-    // Special consideration for subnormal exponent (exp == 00).
-    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
-                                                f32ExpBits, biasAdjustment);
-    Value isManSet =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
-    Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
-
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
-    if (!isa<Float32Type>(resultETy)) {
-      result = b.create<arith::TruncFOp>(resultETy, operand);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
-      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
     Value operand = op.getOperand();
     Type operandTy = operand.getType();
     Type resultTy = op.getType();
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (isa<ShapedType>(operandTy))
-      return failure();
-
     if (!isa<Float4E2M1FNType>(operandETy))
       return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
 
-    SmallVector<int> values = {
-        0x00000000, // 0.0
-        0x3f000000, // 0.5
-        0x3f800000, // 1.0
-        0x3fc00000, // 1.5
-        0x40000000, // 2.0
-        0x40400000, // 3.0
-        0x40800000, // 4.0
-        0x40c00000  // 6.0
-    };
-    // auto type = RankedTensorType::get({8}, b.getI32Type());
-    VectorType type = VectorType::get({8}, b.getI32Type());
-    SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
-        values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
-    Value lookupTable = b.create<arith::ConstantOp>(
-        DenseIntElementsAttr::get(type, lookupTableAttr));
-
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
     Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
-    Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
-
     Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
 
-    Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
-    Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
-    Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
-    Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
-
-    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
-    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
-    Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
-    Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
-    signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
-
-    Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
-    Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
-    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
+    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
+    Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+
+    // Set last Exponent bit and Mantissa.
+    Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
+    Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
+    Value isHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
+    bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
+    bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
+    bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
+
+    // Set first 7 bits of Exponent.
+    Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
+    Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
+    Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
+    Value useLargerExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
+    Value bits25To31 =
+        b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+    Value zeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
+    bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
+
+    // Set sign.
+    Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
+    Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
+    Value negative =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
+    Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+
+    // Add segments together.
+    Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
+    Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
+    Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
     if (!isa<Float32Type>(resultETy))
       result = b.create<arith::TruncFOp>(resultETy, operand);
 
@@ -522,15 +483,15 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
 /// Table of representable values in F4E2M1:
 ///
 /// Note: x is sign bit
-/// | Binary | Value ( + / - )
-/// | x000   | 0.0
-/// | x001   | 0.5
-/// | x010   | 1.0
-/// | x011   | 1.5
-/// | x100   | 2.0
-/// | x101   | 3.0
-/// | x110   | 4.0
-/// | x111   | 6.0
+/// | Binary | F4E2M1 | fp32
+/// | x000   | 0.0    | 0000 0000 00
+/// | x001   | 0.5    | 0011 1111 00
+/// | x010   | 1.0    | 0011 1111 10
+/// | x011   | 1.5    | 0011 1111 11
+/// | x100   | 2.0    | 0100 0000 00
+/// | x101   | 3.0    | 0100 0000 01
+/// | x110   | 4.0    | 0100 0000 10
+/// | x111   | 6.0    | 0100 0000 11
 ///
 /// Conversion procedure:
 ///   Step 1: Clamp to representable bounds.
@@ -543,7 +504,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
                                 PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
     Value operand = op.getOperand();
     Type operandTy = operand.getType();
     Type resultTy = op.getType();
@@ -560,34 +522,30 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
-    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
-    Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
-    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
-    Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
-    Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+    Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
+    Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
+    Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
+    Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
+    Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
 
     // Step 0: Clamp to bounds.
-    Value cHigherBound =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
-    Value cLowerBound =
-        createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
+    Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
+    Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
     Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
     operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
 
     // Step 1: Set sign bit.
-    Value cF32ExpManWidth =
-        createConst(op->getLoc(), i32Ty, 31, rewriter); // 23
+    Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 31
     Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
     Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
     Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
 
     // Step 2: Convert exponent by adjusting bias.
-    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
-    Value cF4MantissaWidth = c0x1; // 1
-    Value cF32MantissaWidth =
-        createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
+    Value cF4MantissaWidth = c0x1;                                   // 1
+    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
     Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
     Value biasAdjustedSignExp =
         b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
@@ -596,16 +554,14 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
 
     // Step 3: Set mantissa to first bit.
-    Value cF32FirstBitMask =
-        createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
     Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
     man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
     Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
     f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
 
     // Step 4: Special consideration for conversion to 0.5.
-    Value cF32MantissaMask =
-        createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
     Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
     Value isSubnormal =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
@@ -613,22 +569,20 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
     Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
     Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
-                                                 man23Bits, c0x00000000);
+                                                 man23Bits, zeroExpBits);
     Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
     Value isZeroExp =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
-    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
-    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+    Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
     Value subResult =
         b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
     subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
     f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
 
     // Step 5: Round up if necessary.
-    Value cF32Last22BitMask =
-        createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
-    Value cRound =
-        createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
+    Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
     Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
     Value shouldRound =
         b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
@@ -848,8 +802,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
 }
 
 void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
-  patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
-               F4E2M1TruncFOpConverter>(patterns.getContext());
+  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
+      patterns.getContext());
 }
 
 void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {



More information about the Mlir-commits mailing list