[Mlir-commits] [mlir] ffa5ce0 - Add arith expansion of f8E8M0 type for extf/trunc ops (#140332)
llvmlistbot at llvm.org
Thu May 22 13:36:03 PDT 2025
Author: Umang Yadav
Date: 2025-05-22T15:36:00-05:00
New Revision: ffa5ce04d0a5440f939881e0e329a173d486dd68
URL: https://github.com/llvm/llvm-project/commit/ffa5ce04d0a5440f939881e0e329a173d486dd68
DIFF: https://github.com/llvm/llvm-project/commit/ffa5ce04d0a5440f939881e0e329a173d486dd68.diff
LOG: Add arith expansion of f8E8M0 type for extf/trunc ops (#140332)
The F8E8M0 floating-point type represents the biased exponent bits of the
F32 type in the OCP Microscaling (MX) floating-point formats:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
This PR expands `arith.truncf` and `arith.extf` to support this type.
One thing to note for `arith.truncf`: the F8E8M0FNU type has a single NaN
representation, encoded as `0xFF`, so all kinds of NaNs and +/-Inf in
Float32Type map to NaN in F8E8M0FNU. F8E8M0FNU also has no sign bit, so the
downcast is lossy and irreversible.
cc: @krzysz00 @MaheshRavishankar @Muzammiluddin-Syed-ECE
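For reference, here is a minimal standalone C++ sketch (not part of the patch;
the helper names are hypothetical) of the bit-level behavior the new rewrite
patterns emit, assuming an IEEE-754 binary32 layout:

#include <cstdint>
#include <cstring>

// truncf f32 -> f8E8M0FNU: keep only the 8 biased exponent bits.
// Sign and mantissa are discarded, so the cast is lossy; any f32 whose
// exponent field is all ones (NaN or +/-Inf) becomes 0xFF, the single
// NaN encoding of f8E8M0FNU.
uint8_t truncToF8E8M0(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint8_t>(bits >> 23); // shrui by 23, then trunci to i8
}

// extf f8E8M0FNU -> f32: shift the exponent bits back above a zero mantissa
// (giving 2^(e-127)); 0xFF expands to an all-ones i32, i.e. an f32 NaN.
float extFromF8E8M0(uint8_t e) {
  uint32_t bits = (e == 0xFF) ? 0xFFFFFFFFu
                              : (static_cast<uint32_t>(e) << 23);
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}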
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arith/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
- let options = [
- Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
- "Enable the BF16 expansion patterns">,
+ let options =
+ [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+ "Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..95546bb09e765 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Creates shapedType using shape from cloneFrom and base type from cloneTo
+static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
+ if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
+ return shapedTy.clone(cloneTo);
+ }
+ return cloneTo;
+}
+
namespace {
/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
}
- Type i16Ty = b.getI16Type();
- Type i32Ty = b.getI32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i16Ty = shapedTy.clone(i16Ty);
- i32Ty = shapedTy.clone(i32Ty);
- }
+ Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}
- Type i16Ty = b.getI16Type();
- Type i32Ty = b.getI32Type();
- Type f32Ty = b.getF32Type();
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i16Ty = shapedTy.clone(i16Ty);
- i32Ty = shapedTy.clone(i32Ty);
- f32Ty = shapedTy.clone(f32Ty);
- }
+ Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
// Algorithm borrowed from this excellent code:
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
- Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+ Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+ }
+
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ // create constants for NaNs
+ Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.getIntOrFloatBitWidth() < 32) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+ result = b.create<arith::ExtFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits out of the F32 type.
+Since all kinds of Infs and NaNs share the same exponent bit pattern in the F32
+type, they all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+ if (operandETy.getIntOrFloatBitWidth() < 32) {
+ operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+ operand = b.create<arith::TruncFOp>(f32Ty, operand);
+ }
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+ Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
if (includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
- target.addDynamicallyLegalOp<arith::ExtFOp>(
- [](arith::ExtFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isBF16() && outETy.isF32());
- });
-
- target.addDynamicallyLegalOp<arith::TruncFOp>(
- [](arith::TruncFOp op) {
- Type inETy = getElementTypeOrSelf(op.getOperand().getType());
- Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isF32() && outETy.isBF16());
- });
}
+ if (includeF8E8M0) {
+ arith::populateExpandF8E8M0Patterns(patterns);
+ }
+
+ target.addDynamicallyLegalOp<arith::ExtFOp>(
+ [=](arith::ExtFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if (includeBf16)
+ legalTypes &= !(inETy.isBF16() && outETy.isF32());
+ if (includeF8E8M0)
+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+ return legalTypes;
+ });
+
+ target.addDynamicallyLegalOp<arith::TruncFOp>(
+ [=](arith::TruncFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ bool legalTypes = true;
+ if (includeBf16)
+ legalTypes &= !(inETy.isF32() && outETy.isBF16());
+ if (includeF8E8M0)
+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+ return legalTypes;
+ });
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+ patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+ patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
// -----
func.func @maxsi(%a: i32, %b: i32) -> i32 {