[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)

Umang Yadav llvmlistbot at llvm.org
Wed May 21 10:05:05 PDT 2025


https://github.com/umangyadav updated https://github.com/llvm/llvm-project/pull/140332

>From 43daddc7c662ae678570050d5402b67c49229da0 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:28:12 +0000
Subject: [PATCH 1/3] Add arith expansion of f8E8M0 type for extf/trunc ops

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   3 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |   2 +
 mlir/include/mlir/IR/Types.h                  |   1 +
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 138 ++++++++++++++++--
 mlir/lib/IR/Types.cpp                         |   2 +-
 mlir/test/Dialect/Arith/expand-ops.mlir       | 130 ++++++++++++++++-
 6 files changed, 265 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith f8E8M0FNU extf/truncf to bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..f97efa52bbaf6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -17,6 +17,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
   let options = [
     Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
            "Enable the BF16 expansion patterns">,
+    Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+           "Enable the F8E8M0 expansion patterns">,
   ];
 }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isF8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // Constants: the f8 and f32 NaN bit patterns, and the f32 mantissa width.
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.isBF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.isF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 extracts the biased exponent bits of the input as an F32 value
+(the sign bit is dropped). Since all Infs and NaNs share the all-ones exponent
+in F32, they all map to NaN (0xff) in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!resultETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+    if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+    if (!operandETy.isF32()) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
       arith::MinNumFOp
     >();
 
-    if (includeBf16) {
+    if(includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
+    }
+    if(includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+    if (includeBf16 || includeF8E8M0) {
       target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
+        [=](arith::ExtFOp op) {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
+          if(includeBf16 && includeF8E8M0)
+            return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+          if(includeBf16)
+            return !(inETy.isBF16() && outETy.isF32());
+          return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
         });
 
       target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op)  {
+        [=](arith::TruncFOp op)  {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
+          if(includeBf16 && includeF8E8M0) 
+            return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
+          if(includeBf16)
+            return !(inETy.isF32() && outETy.isBF16());
+          return 
+            !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
         });
     }
-
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 //===----------------------------------------------------------------------===//
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
 
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f16
+    return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
 // -----
 
 func.func @maxsi(%a: i32, %b: i32) -> i32 {

>From 25a8fddcb3d5e72ad127e235204cbad9a7ad9377 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:55:58 +0000
Subject: [PATCH 2/3] Fix formatting

---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index f97efa52bbaf6..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,11 +14,11 @@ include "mlir/Pass/PassBase.td"
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
-  let options = [
-    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
-           "Enable the BF16 expansion patterns">,
-    Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
-           "Enable the F8E8M0 expansion patterns">,
+  let options =
+      [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+              "Enable the BF16 expansion patterns">,
+       Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+              "Enable the F8E8M0 expansion patterns">,
   ];
 }
 

>From a3dee4858dcd7f5e1f5d223ffb04b372f5148982 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Tue, 20 May 2025 18:36:15 +0000
Subject: [PATCH 3/3] Address review comments

---
 mlir/include/mlir/IR/Types.h                  |  1 -
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 80 +++++++++----------
 mlir/lib/IR/Types.cpp                         |  1 -
 3 files changed, 37 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 55a7c6bb11784..4ffdbfa5b1224 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,7 +109,6 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
-  bool isF8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index f5240cf92bdc4..762cf91092f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const final {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto operand = op.getOperand();
+    Value operand = op.getOperand();
     Type operandTy = operand.getType();
     Type resultTy = op.getType();
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultETy = getElementTypeOrSelf(resultTy);
 
-    if (!operandETy.isF8E8M0FNU()) {
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
       return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
     }
 
-    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
-      return rewriter.notifyMatchFailure(
-          op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
-    }
-
     Type i8Ty = b.getI8Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     // select for NaNs
     f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
-    if (resultETy.isBF16()) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
-    } else if (resultETy.isF16()) {
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
       result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   LogicalResult matchAndRewrite(arith::TruncFOp op,
                                 PatternRewriter &rewriter) const final {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto operand = op.getOperand();
+    Value operand = op.getOperand();
     Type operandTy = operand.getType();
     Type operandETy = getElementTypeOrSelf(operandTy);
     Type resultTy = op.getType();
     Type resultETy = getElementTypeOrSelf(resultTy);
-    if (!resultETy.isF8E8M0FNU()) {
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
       return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
     }
-    if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
-      return rewriter.notifyMatchFailure(
-          op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
-    }
 
     if (op.getRoundingmodeAttr()) {
       return rewriter.notifyMatchFailure(
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       i32Ty = shapedTy.clone(i32Ty);
       f32Ty = shapedTy.clone(f32Ty);
     }
-    if (!operandETy.isF32()) {
+    if (operandETy.getIntOrFloatBitWidth() < 32) {
       operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+      operand = b.create<arith::TruncFOp>(f32Ty, operand);
     }
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
       arith::MinNumFOp
     >();
 
-    if(includeBf16) {
+    if (includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
     }
-    if(includeF8E8M0) {
+    if (includeF8E8M0) {
       arith::populateExpandF8E8M0Patterns(patterns);
     }
-    if (includeBf16 || includeF8E8M0) {
-      target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [=](arith::ExtFOp op) {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          if(includeBf16 && includeF8E8M0)
-            return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
-          if(includeBf16)
-            return !(inETy.isBF16() && outETy.isF32());
-          return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
-        });
-
-      target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [=](arith::TruncFOp op)  {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          if(includeBf16 && includeF8E8M0) 
-            return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
-          if(includeBf16)
-            return !(inETy.isF32() && outETy.isBF16());
-          return 
-            !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
-        });
-    }
+
+    target.addDynamicallyLegalOp<arith::ExtFOp>(
+      [=](arith::ExtFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if(includeBf16) 
+          legalTypes &= !(inETy.isBF16() && outETy.isF32());
+        if(includeF8E8M0)
+          legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+        return legalTypes;
+      });
+
+    target.addDynamicallyLegalOp<arith::TruncFOp>(
+      [=](arith::TruncFOp op)  {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if(includeBf16) 
+          legalTypes &= !(inETy.isF32() && outETy.isBF16());
+        if(includeF8E8M0) 
+          legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); 
+        return legalTypes;
+      });
+
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 975b26ae4369f..ab6f5eda1ad7d 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,6 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 //===----------------------------------------------------------------------===//
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }



More information about the Mlir-commits mailing list