[Mlir-commits] [mlir] ffa5ce0 - Add arith expansion of f8E8M0 type for extf/trunc ops (#140332)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 22 13:36:03 PDT 2025


Author: Umang Yadav
Date: 2025-05-22T15:36:00-05:00
New Revision: ffa5ce04d0a5440f939881e0e329a173d486dd68

URL: https://github.com/llvm/llvm-project/commit/ffa5ce04d0a5440f939881e0e329a173d486dd68
DIFF: https://github.com/llvm/llvm-project/commit/ffa5ce04d0a5440f939881e0e329a173d486dd68.diff

LOG: Add arith expansion of f8E8M0 type for extf/trunc ops (#140332)

The F8E8M0 floating-point type represents the biased exponent bits of the
F32 type, as used in the OCP Microscaling (MX) floating-point formats.


https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR expands `arith.truncf` and `arith.extf` to support this
behavior.

For `arith.truncf`, note that the F8E8M0FNU type has a single NaN
encoding, `0xFF`. Therefore all kinds of NaNs and +/-Inf in Float32Type
map to NaN in F8E8M0FNU. Because F8E8M0FNU has no sign bit, this is a
lossy and irreversible downcast.

cc: @krzysz00  @MaheshRavishankar @Muzammiluddin-Syed-ECE

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand f8E8M0FNU arith.extf/truncf into bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
   let options = [
     Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
            "Enable the BF16 expansion patterns">,
+    Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+           "Enable the F8E8M0 expansion patterns">,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..95546bb09e765 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Applies cloneFrom's shape (if it is a ShapedType) to element type cloneTo.
+static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
+  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
+    return shapedTy.clone(cloneTo);
+  }
+  return cloneTo;
+}
+
 namespace {
 
 /// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
       return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
     }
 
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
 
     Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
     Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
           op, "only applicable to default rounding mode.");
     }
 
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    Type f32Ty = b.getF32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-      f32Ty = shapedTy.clone(f32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
 
     // Algorithm borrowed from this excellent code:
     // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +311,127 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+/// Expands `arith.extf` from f8E8M0FNU (scalar or shaped) into integer ops.
+///
+/// The 8 input bits are zero-extended to i32 and shifted left by the 23-bit
+/// f32 mantissa width, so they become the f32 biased exponent with the sign
+/// and mantissa bits cleared. The single f8E8M0FNU NaN encoding (0xFF) is
+/// instead mapped to the all-ones i32 pattern, which is an f32 NaN. When the
+/// requested result element type is not 32 bits wide, the f32 value is then
+/// converted to it with a further arith.truncf/arith.extf.
+///
+/// NOTE(review): an input of 0x00 produces f32 +0.0. If f8E8M0FNU 0x00 is
+/// meant to denote 2^-127 (the f32 denormal with bit pattern 0x00400000), this
+/// lowering flushes it to zero -- confirm that this is intended.
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    // Integer and f32 types carrying the operand's shape (if any).
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // The f8E8M0FNU NaN encoding, an all-ones (-1) f32 NaN bit pattern, and
+    // the shift that moves the 8 input bits into the f32 exponent field.
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, -1, rewriter);
+    Value cF32MantissaWidth = createConst(op.getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // Replace the shifted bits with the f32 NaN pattern for NaN inputs.
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    // Convert the f32 value to the requested (possibly non-f32) result type.
+    unsigned resultWidth = resultETy.getIntOrFloatBitWidth();
+    if (resultWidth < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultWidth > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Expands `arith.truncf` to f8E8M0FNU (scalar or shaped) into integer ops.
+///
+/// The operand is first brought to f32 (arith.extf when its element type is
+/// narrower than 32 bits, arith.truncf when wider). The f32 bits are then
+/// shifted right by the 23-bit mantissa width and truncated to i8, keeping
+/// the 8-bit biased exponent and dropping the sign bit and the mantissa (the
+/// mantissa is discarded, not rounded). Every f32 Inf and NaN has an all-ones
+/// exponent, so they all map to the f8E8M0FNU NaN encoding (0xFF).
+/// Ops carrying an explicit rounding-mode attribute are not expanded.
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Bring the operand to f32 so that its exponent field can be extracted.
+    unsigned operandWidth = operandETy.getIntOrFloatBitWidth();
+    if (operandWidth < 32) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    } else if (operandWidth > 32) {
+      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op.getLoc(), i32Ty, 23, rewriter);
+    // Sign bit + 8 exponent bits; the trunci below drops the sign bit.
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,41 @@ struct ArithExpandOpsPass
 
     if (includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
-      target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
-        });
-
-      target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op)  {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
-        });
     }
+    if (includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+
+    // Copy the pass options into plain bools so the legality callbacks below
+    // capture them by value rather than implicitly capturing `this` via [=].
+    bool expandBf16 = includeBf16;
+    bool expandF8E8M0 = includeF8E8M0;
+
+    // An extf/truncf is illegal (must be expanded) only when it matches an
+    // enabled expansion; all other extf/truncf ops stay legal.
+    target.addDynamicallyLegalOp<arith::ExtFOp>(
+      [expandBf16, expandF8E8M0](arith::ExtFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if (expandBf16)
+          legalTypes &= !(inETy.isBF16() && outETy.isF32());
+        if (expandF8E8M0)
+          legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+        return legalTypes;
+      });
+
+    target.addDynamicallyLegalOp<arith::TruncFOp>(
+      [expandBf16, expandF8E8M0](arith::TruncFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if (expandBf16)
+          legalTypes &= !(inETy.isF32() && outETy.isBF16());
+        if (expandF8E8M0)
+          legalTypes &= !llvm::isa<Float8E8M0FNUType>(outETy);
+        return legalTypes;
+      });
 
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
@@ -389,6 +494,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
 
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f16
+    return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
 // -----
 
 func.func @maxsi(%a: i32, %b: i32) -> i32 {


        


More information about the Mlir-commits mailing list