[Mlir-commits] [mlir] [mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments (PR #108815)
Daniel Hernandez-Juarez
llvmlistbot at llvm.org
Tue Sep 17 02:34:29 PDT 2024
https://github.com/dhernandez0 updated https://github.com/llvm/llvm-project/pull/108815
>From 0d806f8c6db4aa0f5ec45e0f78ff96bd194663e4 Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Mon, 16 Sep 2024 10:18:00 +0000
Subject: [PATCH 1/2] Refactor LegalizeToF32 to specify extra supported float
types and target type as arguments
Instead of hardcoding that all floating-point types narrower than 32 bits are unsupported, we provide a way to pass the supported floating-point types as well as the target type.
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 4 +
.../mlir/Dialect/Math/Transforms/Passes.h | 12 +-
.../mlir/Dialect/Math/Transforms/Passes.td | 12 +-
.../Transforms/EmulateUnsupportedFloats.cpp | 6 +-
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 22 +++
.../Dialect/Math/Transforms/CMakeLists.txt | 2 +-
.../Transforms/ExtendToSupportedTypes.cpp | 164 ++++++++++++++++++
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 118 -------------
.../Math/extend-to-supported-types-f16.mlir | 137 +++++++++++++++
...32.mlir => extend-to-supported-types.mlir} | 2 +-
10 files changed, 350 insertions(+), 129 deletions(-)
create mode 100644 mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
delete mode 100644 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
create mode 100644 mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
rename mlir/test/Dialect/Math/{legalize-to-f32.mlir => extend-to-supported-types.mlir} (96%)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..d759299cbf7625 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -130,6 +130,10 @@ namespace arith {
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType);
+
+/// Map a float type name (e.g. "bf16", "f8E4M3FN") to its type, or nullopt.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 2dd7f6431f03e1..2974bb344ad965 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns(
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
namespace math {
-void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
-void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
- TypeConverter &typeConverter);
-void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
- TypeConverter &typeConverter);
+void populateExtendToSupportedTypesTypeConverter(
+ TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+ Type targetType);
+void populateExtendToSupportedTypesConversionTarget(
+ ConversionTarget &target, TypeConverter &typeConverter);
+void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter);
} // namespace math
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda58..a84c89020d4f36 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
let dependentDialects = ["math::MathDialect"];
}
-def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
let summary = "Legalize floating-point math ops on low-precision floats";
let description = [{
On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
This pass explicitly legalizes these math functions by inserting
`arith.extf` and `arith.truncf` pairs around said op, which preserves
- the original semantics while enabling lowering.
+  the original semantics while enabling lowering. Values of float types the
+  target does not support are extended to `target-type`; f64, f32 and any
+  types listed in `extra-types` are treated as supported.
As an exception, this pass does not legalize `math.fma`, because
that is an operation frequently implemented at low precisions.
}];
+ let options = [
+ ListOption<"extraTypeStrs", "extra-types", "std::string",
+ "MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
+ Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+ "MLIR type to convert the unsupported source types to">,
+ ];
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 5e5e10b1fa1c2b..67a3a56d555ef2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
@@ -155,7 +156,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
SmallVector<Type> sourceTypes;
Type targetType;
- std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+ std::optional<FloatType> maybeTargetType =
+ arith::parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
@@ -165,7 +167,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
targetType = *maybeTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
std::optional<FloatType> maybeSourceType =
- parseFloatType(ctx, sourceTypeStr);
+ arith::parseFloatType(ctx, sourceTypeStr);
if (!maybeSourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..0c4da0f728074d 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -357,4 +357,26 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
}
+/// Map strings to float types.
+static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
+ StringRef name) {
+ Builder b(ctx);
+ return llvm::StringSwitch<std::optional<FloatType>>(name)
+ .Case("f6E2M3FN", b.getFloat6E2M3FNType())
+ .Case("f6E3M2FN", b.getFloat6E3M2FNType())
+ .Case("f8E5M2", b.getFloat8E5M2Type())
+ .Case("f8E4M3", b.getFloat8E4M3Type())
+ .Case("f8E4M3FN", b.getFloat8E4M3FNType())
+ .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
+ .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+ .Case("f8E3M4", b.getFloat8E3M4Type())
+ .Case("bf16", b.getBF16Type())
+ .Case("f16", b.getF16Type())
+ .Case("f32", b.getF32Type())
+ .Case("f64", b.getF64Type())
+ .Case("f80", b.getF80Type())
+ .Case("f128", b.getF128Type())
+ .Default(std::nullopt);
+}
+
} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb52712..e1c0c2410c1269 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
- LegalizeToF32.cpp
+ ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
new file mode 100644
index 00000000000000..fff2a962333017
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -0,0 +1,198 @@
+//===- ExtendToSupportedTypes.cpp - Extend math ops to supported types ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on floating-point types that
+// the target does not support by widening their operands with arith.extf to a
+// supported type, computing the operation there, and narrowing the results
+// back with arith.truncf.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+/// Applied by the conversion driver to illegal ops of any kind: rebuilds the
+/// op on the converted (extended) types, then truncates each changed result
+/// back to its original type.
+struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
+  ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
+                                       MLIRContext *context)
+      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Parses the `target-type` and `extra-types` options and runs a partial
+/// conversion built from the populate* functions in this file.
+struct ExtendToSupportedTypesPass final
+    : mlir::math::impl::MathExtendToSupportedTypesBase<
+          ExtendToSupportedTypesPass> {
+  using math::impl::MathExtendToSupportedTypesBase<
+      ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Returns true if values of float type `type` must be extended to
+/// `targetType`: `type` is not in `supportedTypes` and is strictly narrower
+/// than the target. `arith.extf` can only widen, so wider or equal-width
+/// unsupported types are left untouched rather than producing invalid IR.
+static bool needsExtension(FloatType type,
+                           const SetVector<Type> &supportedTypes,
+                           Type targetType) {
+  if (supportedTypes.contains(type))
+    return false;
+  auto targetFloatType = dyn_cast<FloatType>(targetType);
+  return targetFloatType && type.getWidth() < targetFloatType.getWidth();
+}
+
+/// Registers conversions on `typeConverter` that replace a float type (or the
+/// float element type of a shaped type) with `targetType` when that float type
+/// is not in `sourceTypes` and is narrower than `targetType`; all other types
+/// convert to themselves. `sourceTypes` is the set of types the target
+/// supports natively. It is copied into the conversion callbacks, so it need
+/// not outlive `typeConverter`. Widening is materialized with `arith.extf`
+/// carrying the `contract` fast-math flag.
+void mlir::math::populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType) {
+  // Fallback: every type is legal as-is unless a conversion registered later
+  // (and therefore tried first) claims it.
+  typeConverter.addConversion(
+      [](Type type) -> std::optional<Type> { return type; });
+  typeConverter.addConversion(
+      [sourceTypes, targetType](FloatType type) -> std::optional<Type> {
+        if (needsExtension(type, sourceTypes, targetType))
+          return targetType;
+        return std::nullopt;
+      });
+  typeConverter.addConversion(
+      [sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
+        // Membership is decided by the element type, not the shaped type.
+        if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
+          if (needsExtension(elemTy, sourceTypes, targetType))
+            return type.clone(targetType);
+        return std::nullopt;
+      });
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
+      });
+}
+
+/// Makes `math` dialect ops legal only when `typeConverter` considers all of
+/// their types legal; ops from other dialects are always legal. `math.fma` is
+/// kept legal because it is frequently implemented at low precision, and so
+/// are the `arith.extf`/`arith.truncf` casts this conversion inserts.
+/// `typeConverter` is captured by reference and must outlive `target`.
+void mlir::math::populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    if (isa<MathDialect>(op->getDialect()))
+      return typeConverter.isLegal(op);
+    return true;
+  });
+  target.addLegalOp<FmaOp>();
+  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
+}
+
+/// Rebuilds `op` on the converted (extended) `operands` with converted result
+/// types, then truncates each result whose type changed back to its original
+/// type so that existing users of `op` are unaffected.
+LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  const TypeConverter *converter = getTypeConverter();
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
+
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
+    if (newType != origType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+      truncFOp.setFastmath(arith::FastMathFlags::contract);
+      result = truncFOp.getResult();
+    }
+  }
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+/// Adds the pattern that extends math ops on unsupported types to `patterns`.
+void mlir::math::populateExtendToSupportedTypesPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter) {
+  patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
+                                                     patterns.getContext());
+}
+
+void ExtendToSupportedTypesPass::runOnOperation() {
+  Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
+
+  // Parse the type that unsupported floats are extended to.
+  std::optional<FloatType> maybeTargetType =
+      arith::parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                    targetTypeStr +
+                                    "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  Type targetType = *maybeTargetType;
+
+  // Collect the types the target supports natively; f64 and f32 always are.
+  llvm::SetVector<Type> supportedTypes;
+  for (const std::string &extraTypeStr : extraTypeStrs) {
+    std::optional<FloatType> maybeExtraType =
+        arith::parseFloatType(ctx, extraTypeStr);
+    if (!maybeExtraType) {
+      emitError(UnknownLoc::get(ctx), "could not map extra type '" +
+                                      extraTypeStr +
+                                      "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    supportedTypes.insert(*maybeExtraType);
+  }
+  Builder b(ctx);
+  supportedTypes.insert(b.getF64Type());
+  supportedTypes.insert(b.getF32Type());
+
+  TypeConverter typeConverter;
+  math::populateExtendToSupportedTypesTypeConverter(typeConverter,
+                                                    supportedTypes, targetType);
+  ConversionTarget target(*ctx);
+  math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
+  RewritePatternSet patterns(ctx);
+  math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
deleted file mode 100644
index 2e60fe455dcade..00000000000000
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements legalizing math operations on small floating-point
-// types through arith.extf and arith.truncf.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHLEGALIZETOF32
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-namespace {
-struct LegalizeToF32RewritePattern final : ConversionPattern {
- LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
- : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-struct LegalizeToF32Pass final
- : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
- void runOnOperation() override;
-};
-} // namespace
-
-void mlir::math::populateLegalizeToF32TypeConverter(
- TypeConverter &typeConverter) {
- typeConverter.addConversion(
- [](Type type) -> std::optional<Type> { return type; });
- typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
- if (type.getWidth() < 32)
- return Float32Type::get(type.getContext());
- return std::nullopt;
- });
- typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
- if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
- return type.clone(Float32Type::get(type.getContext()));
- return std::nullopt;
- });
- typeConverter.addTargetMaterialization(
- [](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
- });
-}
-
-void mlir::math::populateLegalizeToF32ConversionTarget(
- ConversionTarget &target, TypeConverter &typeConverter) {
- target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
- if (isa<MathDialect>(op->getDialect()))
- return typeConverter.isLegal(op);
- return true;
- });
- target.addLegalOp<FmaOp>();
- target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
-}
-
-LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- Location loc = op->getLoc();
- const TypeConverter *converter = getTypeConverter();
- FailureOr<Operation *> legalized =
- convertOpResultTypes(op, operands, *converter, rewriter);
- if (failed(legalized))
- return failure();
-
- SmallVector<Value> results = (*legalized)->getResults();
- for (auto [result, newType, origType] : llvm::zip_equal(
- results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- result = truncFOp.getResult();
- }
- }
- rewriter.replaceOp(op, results);
- return success();
-}
-
-void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
- TypeConverter &typeConverter) {
- patterns.add<LegalizeToF32RewritePattern>(typeConverter,
- patterns.getContext());
-}
-
-void LegalizeToF32Pass::runOnOperation() {
- Operation *op = getOperation();
- MLIRContext &ctx = getContext();
-
- TypeConverter typeConverter;
- math::populateLegalizeToF32TypeConverter(typeConverter);
- ConversionTarget target(ctx);
- math::populateLegalizeToF32ConversionTarget(target, typeConverter);
- RewritePatternSet patterns(&ctx);
- math::populateLegalizeToF32Patterns(patterns, typeConverter);
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
- return signalPassFailure();
-}
diff --git a/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
new file mode 100644
index 00000000000000..6a61b3ba6251fa
--- /dev/null
+++ b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="extra-types=f16 target-type=f32" | FileCheck %s
+
+// CHECK-LABEL: @sin_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+func.func @sin_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+ // CHECK: return [[TRUNCF]] : f8E5M2
+ %0 = math.sin %arg0 : f8E5M2
+ return %0 : f8E5M2
+}
+
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+  // CHECK: [[SIN:%.+]] = math.sin [[ARG0]] : f16
+  // CHECK: return [[SIN]] : f16
+  %0 = math.sin %arg0 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @fpowi_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2, [[ARG1:%.+]]: i32)
+func.func @fpowi_f8E5M2(%arg0: f8E5M2, %arg1: i32) -> f8E5M2 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
+ // CHECK: return [[TRUNCF]] : f8E5M2
+ %0 = math.fpowi %arg0, %arg1 : f8E5M2, i32
+ return %0 : f8E5M2
+}
+
+// CHECK-LABEL: @fpowi
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
+func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
+ // CHECK: [[FPOWI:%.+]] = math.fpowi [[ARG0]], [[ARG1]]
+ // CHECK: return [[FPOWI]] : f16
+ %0 = math.fpowi %arg0, %arg1 : f16, i32
+ return %0 : f16
+}
+
+// COM: Verify that the pass leaves `math.fma` untouched, since it is often
+// COM: implemented on small data types.
+// CHECK-LABEL: @fma
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
+// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK: return [[FMA]] : f16
+func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+ %0 = math.fma %arg0, %arg1, %arg2 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @absf_f16
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f16
+func.func @absf_f16(%arg0: f16) -> f16 {
+ %0 = math.absf %arg0 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @absf_f32
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f32
+func.func @absf_f32(%arg0: f32) -> f32 {
+ %0 = math.absf %arg0 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @absf_f64
+// CHECK-SAME: ([[ARG0:%.+]]: f64)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f64
+func.func @absf_f64(%arg0: f64) -> f64 {
+ %0 = math.absf %arg0 : f64
+ return %0 : f64
+}
+
+// CHECK-LABEL: @sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<2xbf16>
+func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
+ %0 = math.sin %arg0 : vector<2xbf16>
+ return %0 : vector<2xbf16>
+}
+
+// CHECK-LABEL: @fastmath
+// CHECK: math.sin %{{.+}} fastmath<nsz>
+func.func @fastmath(%arg0: f16) -> f16 {
+ %0 = math.sin %arg0 fastmath<nsz> : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @sequences_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF1]] : f8E5M2
+func.func @sequences_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+ %0 = math.absf %arg0 : f8E5M2
+ %1 = math.sin %0 : f8E5M2
+ return %1 : f8E5M2
+}
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: return [[SIN]] : f16
+func.func @sequences(%arg0: f16) -> f16 {
+ %0 = math.absf %arg0 : f16
+ %1 = math.sin %0 : f16
+ return %1 : f16
+}
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf
+  // CHECK-NEXT: %[[RES:.*]] = scf.if
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/extend-to-supported-types.mlir
similarity index 96%
rename from mlir/test/Dialect/Math/legalize-to-f32.mlir
rename to mlir/test/Dialect/Math/extend-to-supported-types.mlir
index ebb0de9d2653e2..ad7169d4cf4ae5 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/extend-to-supported-types.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="target-type=f32" | FileCheck %s
// CHECK-LABEL: @sin
// CHECK-SAME: ([[ARG0:%.+]]: f16)
>From 4bd7966ded038176c47e8c7a05f55882e4be02c8 Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Tue, 17 Sep 2024 09:34:13 +0000
Subject: [PATCH 2/2] Fix parseFloatType
---
.../Transforms/EmulateUnsupportedFloats.cpp | 23 -------------------
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 3 +--
2 files changed, 1 insertion(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 67a3a56d555ef2..b51444e884aaee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -50,29 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
};
} // end namespace
-/// Map strings to float types. This function is here because no one else needs
-/// it yet, feel free to abstract it out.
-static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
- StringRef name) {
- Builder b(ctx);
- return llvm::StringSwitch<std::optional<FloatType>>(name)
- .Case("f6E2M3FN", b.getFloat6E2M3FNType())
- .Case("f6E3M2FN", b.getFloat6E3M2FNType())
- .Case("f8E5M2", b.getFloat8E5M2Type())
- .Case("f8E4M3", b.getFloat8E4M3Type())
- .Case("f8E4M3FN", b.getFloat8E4M3FNType())
- .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
- .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
- .Case("f8E3M4", b.getFloat8E3M4Type())
- .Case("bf16", b.getBF16Type())
- .Case("f16", b.getF16Type())
- .Case("f32", b.getF32Type())
- .Case("f64", b.getF64Type())
- .Case("f80", b.getF80Type())
- .Case("f128", b.getF128Type())
- .Default(std::nullopt);
-}
-
LogicalResult EmulateFloatPattern::match(Operation *op) const {
if (getTypeConverter()->isLegal(op))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0c4da0f728074d..4874b66ea58906 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -358,8 +358,7 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
}
/// Map strings to float types.
-static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
- StringRef name) {
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
More information about the Mlir-commits
mailing list