[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)
Umang Yadav
llvmlistbot at llvm.org
Wed May 21 10:05:05 PDT 2025
https://github.com/umangyadav updated https://github.com/llvm/llvm-project/pull/140332
>From 43daddc7c662ae678570050d5402b67c49229da0 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:28:12 +0000
Subject: [PATCH 1/3] Add arith expansion of f8E8M0 type for extf/trunc ops
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 3 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 2 +
mlir/include/mlir/IR/Types.h | 1 +
.../Dialect/Arith/Transforms/ExpandOps.cpp | 138 ++++++++++++++++--
mlir/lib/IR/Types.cpp | 2 +-
mlir/test/Dialect/Arith/expand-ops.mlir | 130 ++++++++++++++++-
6 files changed, 265 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith extf/truncf ops on f8E8M0FNU to lower level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..f97efa52bbaf6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -17,6 +17,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
let options = [
Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
"Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
+ bool isF8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
- Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+ Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { // Expands extf from f8E8M0FNU with integer bit ops.
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+ }
+
+ if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+ // f8E8M0FNU has no sign or mantissa bits: its raw byte becomes the f32 exponent field.
+ Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ // NaN constants: 0xFF is the f8E8M0FNU NaN encoding; all-ones i32 is an f32 NaN.
+ Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ // Zero-extend and shift the byte into f32 bits 23..30 (sign and mantissa stay 0).
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+ // NOTE(review): byte 0x00 yields f32 bits 0 (+0.0); if it should decode to 2^-127 (f32 subnormal 0x00400000), it needs its own select - confirm.
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.isBF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.isF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/*
+TruncF to f8E8M0FNU keeps only the 8-bit biased exponent of the value viewed
+as f32: the sign bit and the 23 mantissa bits are discarded. All f32 Infs and
+NaNs share exponent field 0xFF, so they all map to NaN (0xFF) in f8E8M0FNU.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!resultETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+ if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+ }
+
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ } // NOTE(review): the mantissa is truncated below, so |x| is rounded down to a power of two rather than to nearest - confirm this is the intended default rounding.
+ // bf16/f16 inputs are extended to f32 below, so the exponent always sits in bits 23..30.
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+ if (!operandETy.isF32()) {
+ operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ }
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth); // bits 0..7 = biased exponent, bit 8 = sign
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp); // drops the sign; f32 zeros/subnormals give 0x00
+ Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();
- if (includeBf16) {
+ if(includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
+ }
+ if(includeF8E8M0) {
+ arith::populateExpandF8E8M0Patterns(patterns);
+ }
+ if (includeBf16 || includeF8E8M0) {
target.addDynamicallyLegalOp<arith::ExtFOp>(
- [](arith::ExtFOp op) {
+ [=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isBF16() && outETy.isF32());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+ if(includeBf16)
+ return !(inETy.isBF16() && outETy.isF32());
+ return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
});
target.addDynamicallyLegalOp<arith::TruncFOp>(
- [](arith::TruncFOp op) {
+ [=](arith::TruncFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isF32() && outETy.isBF16());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
+ if(includeBf16)
+ return !(inETy.isF32() && outETy.isBF16());
+ return
+ !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
});
}
-
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { // Registers the f8E8M0FNU extf/truncf expansions.
+ patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+ patterns.getContext()); // ArithExpandOpsPass calls this only when include-f8e8m0 is set.
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
// -----
func.func @maxsi(%a: i32, %b: i32) -> i32 {
>From 25a8fddcb3d5e72ad127e235204cbad9a7ad9377 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:55:58 +0000
Subject: [PATCH 2/3] Fix formatting
---
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index f97efa52bbaf6..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,11 +14,11 @@ include "mlir/Pass/PassBase.td"
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
- let options = [
- Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
- "Enable the BF16 expansion patterns">,
- Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
- "Enable the F8E8M0 expansion patterns">,
+ let options =
+ [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+ "Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
>From a3dee4858dcd7f5e1f5d223ffb04b372f5148982 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Tue, 20 May 2025 18:36:15 +0000
Subject: [PATCH 3/3] Address review comments
---
mlir/include/mlir/IR/Types.h | 1 -
.../Dialect/Arith/Transforms/ExpandOps.cpp | 80 +++++++++----------
mlir/lib/IR/Types.cpp | 1 -
3 files changed, 37 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 55a7c6bb11784..4ffdbfa5b1224 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,7 +109,6 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
- bool isF8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index f5240cf92bdc4..762cf91092f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto operand = op.getOperand();
+ Value operand = op.getOperand();
Type operandTy = operand.getType();
Type resultTy = op.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!operandETy.isF8E8M0FNU()) {
+ if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}
- if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
- return rewriter.notifyMatchFailure(
- op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
- }
-
Type i8Ty = b.getI8Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
// select for NaNs
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.isBF16()) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.isF16()) {
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
}
rewriter.replaceOp(op, result);
return success();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
LogicalResult matchAndRewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto operand = op.getOperand();
+ Value operand = op.getOperand();
Type operandTy = operand.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultTy = op.getType();
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!resultETy.isF8E8M0FNU()) {
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
}
- if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
- return rewriter.notifyMatchFailure(
- op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
- }
if (op.getRoundingmodeAttr()) {
return rewriter.notifyMatchFailure(
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}
- if (!operandETy.isF32()) {
+ if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+ operand = b.create<arith::TruncFOp>(f32Ty, operand);
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();
- if(includeBf16) {
+ if (includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
}
- if(includeF8E8M0) {
+ if (includeF8E8M0) {
arith::populateExpandF8E8M0Patterns(patterns);
}
- if (includeBf16 || includeF8E8M0) {
- target.addDynamicallyLegalOp<arith::ExtFOp>(
- [=](arith::ExtFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- if(includeBf16 && includeF8E8M0)
- return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
- if(includeBf16)
- return !(inETy.isBF16() && outETy.isF32());
- return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
- });
-
- target.addDynamicallyLegalOp<arith::TruncFOp>(
- [=](arith::TruncFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- if(includeBf16 && includeF8E8M0)
- return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
- if(includeBf16)
- return !(inETy.isF32() && outETy.isBF16());
- return
- !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
- });
- }
+
+ target.addDynamicallyLegalOp<arith::ExtFOp>(
+ [=](arith::ExtFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if(includeBf16)
+ legalTypes &= !(inETy.isBF16() && outETy.isF32());
+ if(includeF8E8M0)
+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+ return legalTypes;
+ });
+
+ target.addDynamicallyLegalOp<arith::TruncFOp>(
+ [=](arith::TruncFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if(includeBf16)
+ legalTypes &= !(inETy.isF32() && outETy.isBF16());
+ if(includeF8E8M0)
+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+ return legalTypes;
+ });
+
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 975b26ae4369f..ab6f5eda1ad7d 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,6 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
More information about the Mlir-commits
mailing list