[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)
Umang Yadav
llvmlistbot at llvm.org
Thu May 22 06:18:08 PDT 2025
https://github.com/umangyadav updated https://github.com/llvm/llvm-project/pull/140332
>From 062a982bced9410c7af133d7c2363bdbc1980408 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:28:12 +0000
Subject: [PATCH 1/6] Add arith expansion of f8E8M0 type for extf/truncf ops
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 3 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 2 +
mlir/include/mlir/IR/Types.h | 1 +
.../Dialect/Arith/Transforms/ExpandOps.cpp | 138 ++++++++++++++++--
mlir/lib/IR/Types.cpp | 2 +-
mlir/test/Dialect/Arith/expand-ops.mlir | 130 ++++++++++++++++-
6 files changed, 265 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand arith.extf from f8E8M0FNU and arith.truncf to
+/// f8E8M0FNU into lower level bitcasts/shifts. Only truncf ops without an
+/// explicit rounding mode are expanded.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..f97efa52bbaf6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -17,6 +17,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
let options = [
Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
"Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
+ bool isF8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
- Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+ Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+ }
+
+ if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+
+ Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ // create constants for NaNs
+ Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.isBF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.isF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
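+/// Expands arith.truncf to f8E8M0FNU by extracting the biased exponent field
+/// of the F32 value (other float types are first converted to F32); the sign
+/// and mantissa bits are discarded. Every F32 Inf and NaN has an all-ones
+/// exponent, so they all map to the F8E8M0FNU NaN encoding (0xFF), while zero
+/// and F32 denormals have a zero exponent and map to 0x00 (2^-127).
+/// NOTE(review): F8E8M0ExtFOpConverter expands 0x00 to +0.0 rather than the
+/// F32 denormal 2^-127 (0x00400000); confirm this is intended.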
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!resultETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+ if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+ }
+
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+ if (!operandETy.isF32()) {
+ operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ }
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+ Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();
- if (includeBf16) {
+ if(includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
+ }
+ if(includeF8E8M0) {
+ arith::populateExpandF8E8M0Patterns(patterns);
+ }
+ if (includeBf16 || includeF8E8M0) {
target.addDynamicallyLegalOp<arith::ExtFOp>(
- [](arith::ExtFOp op) {
+ [=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isBF16() && outETy.isF32());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+ if(includeBf16)
+ return !(inETy.isBF16() && outETy.isF32());
+ return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
});
target.addDynamicallyLegalOp<arith::TruncFOp>(
- [](arith::TruncFOp op) {
+ [=](arith::TruncFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isF32() && outETy.isBF16());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
+ if(includeBf16)
+ return !(inETy.isF32() && outETy.isBF16());
+ return
+ !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
});
}
-
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+/// Registers the f8E8M0FNU arith.extf and arith.truncf expansion patterns.
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  auto *ctx = patterns.getContext();
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(ctx);
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+// -----
+// f32 -> f8E8M0FNU keeps only the exponent byte: shrui by 23, then trunci.
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
// CHECK: return %[[RESULT]]
+
+// -----
+
+// f16 -> f8E8M0FNU is first extended to f32, then expanded like f32.
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+// f8E8M0FNU -> f32 places the byte in the F32 exponent field (shli by 23);
+// the 0xFF NaN encoding is selected to an all-ones F32 NaN.
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f32
+ return %0 : f32
+}
+
// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
// CHECK: return %[[RESULT]]
+
+// -----
+
+// f8E8M0FNU -> f16 reuses the f32 expansion, then truncates f32 -> f16.
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f16
+ return %0 : f16
+}
+
// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
// -----
func.func @maxsi(%a: i32, %b: i32) -> i32 {
>From da2545847071d1802ed0f7979b5ecdabbf624bc9 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Sat, 17 May 2025 01:55:58 +0000
Subject: [PATCH 2/6] Fix formatting
---
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index f97efa52bbaf6..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,11 +14,11 @@ include "mlir/Pass/PassBase.td"
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
- let options = [
- Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
- "Enable the BF16 expansion patterns">,
- Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
- "Enable the F8E8M0 expansion patterns">,
+ let options =
+ [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+ "Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
>From 49683cc38e7f79c9021d85e141d51ed889a175ca Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Tue, 20 May 2025 18:36:15 +0000
Subject: [PATCH 3/6] Address review comments
---
mlir/include/mlir/IR/Types.h | 1 -
.../Dialect/Arith/Transforms/ExpandOps.cpp | 80 +++++++++----------
mlir/lib/IR/Types.cpp | 1 -
3 files changed, 37 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 55a7c6bb11784..4ffdbfa5b1224 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,7 +109,6 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
- bool isF8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index f5240cf92bdc4..762cf91092f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto operand = op.getOperand();
+ Value operand = op.getOperand();
Type operandTy = operand.getType();
Type resultTy = op.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!operandETy.isF8E8M0FNU()) {
+ if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}
- if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
- return rewriter.notifyMatchFailure(
- op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
- }
-
Type i8Ty = b.getI8Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
// select for NaNs
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
- if (resultETy.isBF16()) {
- result = b.create<arith::TruncFOp>(resultTy, result);
- } else if (resultETy.isF16()) {
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
}
rewriter.replaceOp(op, result);
return success();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
LogicalResult matchAndRewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto operand = op.getOperand();
+ Value operand = op.getOperand();
Type operandTy = operand.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultTy = op.getType();
Type resultETy = getElementTypeOrSelf(resultTy);
- if (!resultETy.isF8E8M0FNU()) {
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
}
- if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
- return rewriter.notifyMatchFailure(
- op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
- }
if (op.getRoundingmodeAttr()) {
return rewriter.notifyMatchFailure(
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}
- if (!operandETy.isF32()) {
+ if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+ operand = b.create<arith::TruncFOp>(f32Ty, operand);
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();
- if(includeBf16) {
+ if (includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
}
- if(includeF8E8M0) {
+ if (includeF8E8M0) {
arith::populateExpandF8E8M0Patterns(patterns);
}
- if (includeBf16 || includeF8E8M0) {
- target.addDynamicallyLegalOp<arith::ExtFOp>(
- [=](arith::ExtFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- if(includeBf16 && includeF8E8M0)
- return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
- if(includeBf16)
- return !(inETy.isBF16() && outETy.isF32());
- return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
- });
-
- target.addDynamicallyLegalOp<arith::TruncFOp>(
- [=](arith::TruncFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- if(includeBf16 && includeF8E8M0)
- return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
- if(includeBf16)
- return !(inETy.isF32() && outETy.isBF16());
- return
- !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
- });
- }
+
+ target.addDynamicallyLegalOp<arith::ExtFOp>(
+ [=](arith::ExtFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if(includeBf16)
+ legalTypes &= !(inETy.isBF16() && outETy.isF32());
+ if(includeF8E8M0)
+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+ return legalTypes;
+ });
+
+ target.addDynamicallyLegalOp<arith::TruncFOp>(
+ [=](arith::TruncFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if(includeBf16)
+ legalTypes &= !(inETy.isF32() && outETy.isBF16());
+ if(includeF8E8M0)
+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+ return legalTypes;
+ });
+
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 975b26ae4369f..ab6f5eda1ad7d 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,6 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
>From 878e69a000877e1137fb17ef5a18d4af10c65425 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 22 May 2025 12:55:53 +0000
Subject: [PATCH 4/6] Add missing spaces after `if` keywords
---
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 762cf91092f86..09bb6c5ef72c8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -458,9 +458,9 @@ struct ArithExpandOpsPass
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if(includeBf16)
+ if (includeBf16)
legalTypes &= !(inETy.isBF16() && outETy.isF32());
- if(includeF8E8M0)
+ if (includeF8E8M0)
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
return legalTypes;
});
@@ -470,9 +470,9 @@ struct ArithExpandOpsPass
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
- if(includeBf16)
+ if (includeBf16)
legalTypes &= !(inETy.isF32() && outETy.isBF16());
- if(includeF8E8M0)
+ if (includeF8E8M0)
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
return legalTypes;
});
>From f72e9537ccedb5a4ac067210227a2a4db87c96c1 Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 22 May 2025 13:07:14 +0000
Subject: [PATCH 5/6] Add cloneToShapedType static helper function
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 49 +++++++------------
1 file changed, 19 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 09bb6c5ef72c8..95546bb09e765 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Returns `cloneTo` wrapped in the shape of `cloneFrom` when `cloneFrom` is
+/// a ShapedType (e.g. a vector); otherwise returns `cloneTo` unchanged.
+static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
+  auto shapedTy = dyn_cast<ShapedType>(cloneFrom);
+  if (!shapedTy)
+    return cloneTo;
+  return shapedTy.clone(cloneTo);
+}
+
namespace {
/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
}
- Type i16Ty = b.getI16Type();
- Type i32Ty = b.getI32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i16Ty = shapedTy.clone(i16Ty);
- i32Ty = shapedTy.clone(i32Ty);
- }
+ Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}
- Type i16Ty = b.getI16Type();
- Type i32Ty = b.getI32Type();
- Type f32Ty = b.getF32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i16Ty = shapedTy.clone(i16Ty);
- i32Ty = shapedTy.clone(i32Ty);
- f32Ty = shapedTy.clone(f32Ty);
- }
+ Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
// Algorithm borrowed from this excellent code:
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -340,14 +338,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}
- Type i8Ty = b.getI8Type();
- Type i32Ty = b.getI32Type();
- Type f32Ty = b.getF32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i8Ty = shapedTy.clone(i8Ty);
- i32Ty = shapedTy.clone(i32Ty);
- f32Ty = shapedTy.clone(f32Ty);
- }
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
// create constants for NaNs
@@ -397,14 +390,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}
- Type i8Ty = b.getI8Type();
- Type i32Ty = b.getI32Type();
- Type f32Ty = b.getF32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i8Ty = shapedTy.clone(i8Ty);
- i32Ty = shapedTy.clone(i32Ty);
- f32Ty = shapedTy.clone(f32Ty);
- }
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
>From f24f8d2913082a544e2bfa6039fda6b9f06a670d Mon Sep 17 00:00:00 2001
From: Umang Yadav <umayadav at amd.com>
Date: Thu, 22 May 2025 13:17:45 +0000
Subject: [PATCH 6/6] undo unrelated change
---
mlir/lib/IR/Types.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index ab6f5eda1ad7d..765b787d3d17a 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,6 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
+
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
More information about the Mlir-commits
mailing list