[Mlir-commits] [mlir] 1fd1f65 - [mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments (#108815)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 27 08:02:20 PDT 2024


Author: Daniel Hernandez-Juarez
Date: 2024-09-27T10:02:16-05:00
New Revision: 1fd1f65569f565b5b06fd9464b3b91fcd6f2fa2a

URL: https://github.com/llvm/llvm-project/commit/1fd1f65569f565b5b06fd9464b3b91fcd6f2fa2a
DIFF: https://github.com/llvm/llvm-project/commit/1fd1f65569f565b5b06fd9464b3b91fcd6f2fa2a.diff

LOG:  [mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments (#108815)

Instead of hardcoding that all floating-point types narrower than 32 bits
are unsupported, we provide a way to pass the supported floating-point
types as well as the target type. fp64 and fp32 are implicitly supported.

CC: @krzysz00 @manupak

Added: 
    mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
    mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
    mlir/test/Dialect/Math/extend-to-supported-types.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/Utils/Utils.h
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/include/mlir/Dialect/Math/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
    mlir/lib/Dialect/Arith/Utils/Utils.cpp
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
    mlir/test/Dialect/Math/legalize-to-f32.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..d759299cbf7625 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -130,6 +130,10 @@ namespace arith {
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
                     Type resultType);
+
+/// Maps the textual name of a builtin floating-point type (e.g. "f16",
+/// "bf16", "f8E4M3FN") to the corresponding FloatType in `ctx`.
+/// Returns std::nullopt if `name` is not a recognized float type name.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+
 } // namespace arith
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 2dd7f6431f03e1..2974bb344ad965 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns(
 void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
 
 namespace math {
-void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
-void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
-                                           TypeConverter &typeConverter);
-void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
-                                   TypeConverter &typeConverter);
+/// Registers type conversions under which float types that are not in
+/// `sourceTypes` (and shaped types with such a float element type) are
+/// rewritten to `targetType`. Despite its name, `sourceTypes` is the set of
+/// float types treated as supported. The conversion is materialized with
+/// `arith.extf`, so `targetType` is expected to be wider than the types being
+/// extended.
+void populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType);
+/// Makes math-dialect ops legal only when their types are legal under
+/// `typeConverter`. `math.fma`, `arith.extf`, `arith.truncf` and ops from
+/// other dialects are always legal.
+void populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter);
+/// Adds the pattern that rebuilds illegal ops on converted types and
+/// truncates their results back to the original types.
+void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
+                                            TypeConverter &typeConverter);
 } // namespace math
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda58..a84c89020d4f36 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
   let dependentDialects = ["math::MathDialect"];
 }
 
-def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
   let summary = "Legalize floating-point math ops on low-precision floats";
   let description = [{
     On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
 
     This pass explicitly legalizes these math functions by inserting
     `arith.extf` and `arith.truncf` pairs around said op, which preserves
-    the original semantics while enabling lowering.
+    the original semantics while enabling lowering. The extra floating-point
+    types supported by the target are passed as arguments, along with the
+    target type they are extended to; f64 and f32 are implicitly supported.
 
     As an exception, this pass does not legalize `math.fma`, because
     that is an operation frequently implemented at low precisions.
   }];
+  let options = [
+    ListOption<"extraTypeStrs", "extra-types", "std::string",
+      "MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
+    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+      "MLIR type to convert the unsupported source types to">,
+  ];
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 

diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 0bf8c8942885e6..b51444e884aaee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
@@ -49,30 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
 };
 } // end namespace
 
-/// Map strings to float types. This function is here because no one else needs
-/// it yet, feel free to abstract it out.
-static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
-                                               StringRef name) {
-  Builder b(ctx);
-  return llvm::StringSwitch<std::optional<FloatType>>(name)
-      .Case("f4E2M1FN", b.getFloat4E2M1FNType())
-      .Case("f6E2M3FN", b.getFloat6E2M3FNType())
-      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
-      .Case("f8E5M2", b.getFloat8E5M2Type())
-      .Case("f8E4M3", b.getFloat8E4M3Type())
-      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
-      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
-      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
-      .Case("f8E3M4", b.getFloat8E3M4Type())
-      .Case("bf16", b.getBF16Type())
-      .Case("f16", b.getF16Type())
-      .Case("f32", b.getF32Type())
-      .Case("f64", b.getF64Type())
-      .Case("f80", b.getF80Type())
-      .Case("f128", b.getF128Type())
-      .Default(std::nullopt);
-}
-
 LogicalResult EmulateFloatPattern::match(Operation *op) const {
   if (getTypeConverter()->isLegal(op))
     return failure();
@@ -156,7 +133,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   SmallVector<Type> sourceTypes;
   Type targetType;
 
-  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+  std::optional<FloatType> maybeTargetType =
+      arith::parseFloatType(ctx, targetTypeStr);
   if (!maybeTargetType) {
     emitError(UnknownLoc::get(ctx), "could not map target type '" +
                                         targetTypeStr +
@@ -166,7 +144,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   targetType = *maybeTargetType;
   for (StringRef sourceTypeStr : sourceTypeStrs) {
     std::optional<FloatType> maybeSourceType =
-        parseFloatType(ctx, sourceTypeStr);
+        arith::parseFloatType(ctx, sourceTypeStr);
     if (!maybeSourceType) {
       emitError(UnknownLoc::get(ctx), "could not map source type '" +
                                           sourceTypeStr +

diff  --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..c0aa16cc0da407 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -357,4 +357,26 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
       [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
 }
 
+/// Returns the builtin float type named `name` (e.g. "f16", "bf16",
+/// "f8E4M3FN"), created in `ctx`, or std::nullopt if `name` is not one of the
+/// float type names recognized here.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
+  Builder builder(ctx);
+
+  // Sub-byte and 8-bit formats.
+  if (name == "f4E2M1FN")
+    return builder.getFloat4E2M1FNType();
+  if (name == "f6E2M3FN")
+    return builder.getFloat6E2M3FNType();
+  if (name == "f6E3M2FN")
+    return builder.getFloat6E3M2FNType();
+  if (name == "f8E5M2")
+    return builder.getFloat8E5M2Type();
+  if (name == "f8E4M3")
+    return builder.getFloat8E4M3Type();
+  if (name == "f8E4M3FN")
+    return builder.getFloat8E4M3FNType();
+  if (name == "f8E5M2FNUZ")
+    return builder.getFloat8E5M2FNUZType();
+  if (name == "f8E4M3FNUZ")
+    return builder.getFloat8E4M3FNUZType();
+  if (name == "f8E3M4")
+    return builder.getFloat8E3M4Type();
+
+  // 16-bit and wider formats.
+  if (name == "bf16")
+    return builder.getBF16Type();
+  if (name == "f16")
+    return builder.getF16Type();
+  if (name == "f32")
+    return builder.getF32Type();
+  if (name == "f64")
+    return builder.getF64Type();
+  if (name == "f80")
+    return builder.getF80Type();
+  if (name == "f128")
+    return builder.getF128Type();
+
+  // Not a recognized float type name.
+  return std::nullopt;
+}
+
 } // namespace mlir::arith

diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb52712..e1c0c2410c1269 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
-  LegalizeToF32.cpp
+  ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
new file mode 100644
index 00000000000000..1a9eafec9fdd57
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -0,0 +1,164 @@
+//===- ExtendToSupportedTypes.cpp - Extend math ops to supported types ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on unsupported floating-point
+// types through arith.extf and arith.truncf.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+/// Catch-all conversion pattern: it is registered for every operation kind
+/// and relies on the type converter and conversion target to decide which
+/// ops actually need to be rewritten.
+struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
+  ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
+                                       MLIRContext *context)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1,
+                          context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Driver for `math-extend-to-supported-types`. The `extra-types` and
+/// `target-type` options come from the TableGen-generated base class.
+struct ExtendToSupportedTypesPass
+    : mlir::math::impl::MathExtendToSupportedTypesBase<
+          ExtendToSupportedTypesPass> {
+  using Base = mlir::math::impl::MathExtendToSupportedTypesBase<
+      ExtendToSupportedTypesPass>;
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Registers type conversions that extend unsupported floating-point types.
+///
+/// A float type, or the float element type of a shaped type, is rewritten to
+/// `targetType` when both of the following hold:
+///   - it is not in `sourceTypes`, the set of float types the target supports;
+///   - it is strictly narrower than `targetType`.
+/// The width check is required because the conversion is materialized with
+/// `arith.extf`, which only accepts a strictly wider result type. Without it,
+/// types such as f80/f128 (or bf16 with an f16 target) would be "extended"
+/// into invalid IR. All other types convert to themselves.
+///
+/// `sourceTypes` is copied into the registered callbacks, so it does not need
+/// to outlive `typeConverter`.
+void mlir::math::populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType) {
+  FloatType targetFloatType =
+      targetType ? dyn_cast<FloatType>(targetType) : FloatType();
+
+  // True iff `type` is unsupported and can legally be extended to the target.
+  auto needsExtension = [sourceTypes, targetFloatType](FloatType type) {
+    return !sourceTypes.contains(type) && targetFloatType &&
+           type.getWidth() < targetFloatType.getWidth();
+  };
+
+  // Fallback: every type is legal as-is. Conversions added later are tried
+  // first, so the float/shaped rules below take precedence.
+  typeConverter.addConversion(
+      [](Type type) -> std::optional<Type> { return type; });
+  typeConverter.addConversion(
+      [needsExtension, targetType](FloatType type) -> std::optional<Type> {
+        if (needsExtension(type))
+          return targetType;
+        return std::nullopt;
+      });
+  typeConverter.addConversion(
+      [needsExtension, targetType](ShapedType type) -> std::optional<Type> {
+        if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
+          if (needsExtension(elemTy))
+            return type.clone(targetType);
+        return std::nullopt;
+      });
+  // Operands of the wider type are produced with a contract-fastmath extf.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
+      });
+}
+
+/// Configures `target` as follows:
+///   - Math-dialect ops are legal only when `typeConverter` considers their
+///     types legal.
+///   - Ops from any other dialect (or with no loaded dialect) are legal.
+///   - `math.fma` is always legal, because it is frequently implemented at
+///     low precision.
+///   - The `arith.extf` / `arith.truncf` ops inserted by the rewrite are
+///     always legal.
+void mlir::math::populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    // Unregistered ops have no dialect (getDialect() returns null), so use a
+    // null-tolerant check instead of isa<>, which asserts on null.
+    if (llvm::isa_and_nonnull<MathDialect>(op->getDialect()))
+      return typeConverter.isLegal(op);
+    return true;
+  });
+  target.addLegalOp<FmaOp>();
+  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
+}
+
+/// Rebuilds `op` with result types converted by the type converter; the
+/// driver has already remapped `operands`. Each result whose type changed is
+/// truncated back to its original type with a contract-fastmath
+/// `arith.truncf` before `op` is replaced.
+LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  FailureOr<Operation *> converted =
+      convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+  if (failed(converted))
+    return failure();
+
+  Operation *newOp = *converted;
+  SmallVector<Value> replacements = newOp->getResults();
+  for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx) {
+    Type origType = op->getResult(idx).getType();
+    Value &replacement = replacements[idx];
+    if (replacement.getType() == origType)
+      continue;
+    auto truncFOp =
+        rewriter.create<arith::TruncFOp>(op->getLoc(), origType, replacement);
+    truncFOp.setFastmath(arith::FastMathFlags::contract);
+    replacement = truncFOp.getResult();
+  }
+  rewriter.replaceOp(op, replacements);
+  return success();
+}
+
+/// Adds the catch-all ExtendToSupportedTypesRewritePattern, bound to
+/// `typeConverter`, to `patterns`.
+void mlir::math::populateExtendToSupportedTypesPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter, context);
+}
+
+/// Parses the `target-type` and `extra-types` options, then runs a partial
+/// conversion that extends math ops on unsupported float types. Fails the
+/// pass if any option value is not a known floating-point type name.
+void ExtendToSupportedTypesPass::runOnOperation() {
+  Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
+
+  // Parse the type that unsupported float types are extended to.
+  std::optional<FloatType> maybeTargetType =
+      arith::parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType.has_value()) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                        targetTypeStr +
+                                        "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  Type targetType = maybeTargetType.value();
+
+  // Parse the extra float types the target supports (`extra-types` option).
+  llvm::SetVector<Type> supportedTypes;
+  for (const auto &extraTypeStr : extraTypeStrs) {
+    std::optional<FloatType> maybeExtraType =
+        arith::parseFloatType(ctx, extraTypeStr);
+    if (!maybeExtraType.has_value()) {
+      // Name the option the user actually passed ("extra-types").
+      emitError(UnknownLoc::get(ctx), "could not map extra type '" +
+                                          extraTypeStr +
+                                          "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    supportedTypes.insert(maybeExtraType.value());
+  }
+  // f64 and f32 are implicitly supported.
+  Builder b(ctx);
+  supportedTypes.insert(b.getF64Type());
+  supportedTypes.insert(b.getF32Type());
+
+  TypeConverter typeConverter;
+  math::populateExtendToSupportedTypesTypeConverter(
+      typeConverter, supportedTypes, targetType);
+  ConversionTarget target(*ctx);
+  math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
+  RewritePatternSet patterns(ctx);
+  math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    return signalPassFailure();
+}

diff  --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
deleted file mode 100644
index 2e60fe455dcade..00000000000000
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements legalizing math operations on small floating-point
-// types through arith.extf and arith.truncf.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHLEGALIZETOF32
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-namespace {
-struct LegalizeToF32RewritePattern final : ConversionPattern {
-  LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
-      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-struct LegalizeToF32Pass final
-    : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
-  void runOnOperation() override;
-};
-} // namespace
-
-void mlir::math::populateLegalizeToF32TypeConverter(
-    TypeConverter &typeConverter) {
-  typeConverter.addConversion(
-      [](Type type) -> std::optional<Type> { return type; });
-  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
-    if (type.getWidth() < 32)
-      return Float32Type::get(type.getContext());
-    return std::nullopt;
-  });
-  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
-    if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
-      return type.clone(Float32Type::get(type.getContext()));
-    return std::nullopt;
-  });
-  typeConverter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
-}
-
-void mlir::math::populateLegalizeToF32ConversionTarget(
-    ConversionTarget &target, TypeConverter &typeConverter) {
-  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
-    if (isa<MathDialect>(op->getDialect()))
-      return typeConverter.isLegal(op);
-    return true;
-  });
-  target.addLegalOp<FmaOp>();
-  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
-}
-
-LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  const TypeConverter *converter = getTypeConverter();
-  FailureOr<Operation *> legalized =
-      convertOpResultTypes(op, operands, *converter, rewriter);
-  if (failed(legalized))
-    return failure();
-
-  SmallVector<Value> results = (*legalized)->getResults();
-  for (auto [result, newType, origType] : llvm::zip_equal(
-           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      result = truncFOp.getResult();
-    }
-  }
-  rewriter.replaceOp(op, results);
-  return success();
-}
-
-void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
-                                               TypeConverter &typeConverter) {
-  patterns.add<LegalizeToF32RewritePattern>(typeConverter,
-                                            patterns.getContext());
-}
-
-void LegalizeToF32Pass::runOnOperation() {
-  Operation *op = getOperation();
-  MLIRContext &ctx = getContext();
-
-  TypeConverter typeConverter;
-  math::populateLegalizeToF32TypeConverter(typeConverter);
-  ConversionTarget target(ctx);
-  math::populateLegalizeToF32ConversionTarget(target, typeConverter);
-  RewritePatternSet patterns(&ctx);
-  math::populateLegalizeToF32Patterns(patterns, typeConverter);
-  if (failed(applyPartialConversion(op, target, std::move(patterns))))
-    return signalPassFailure();
-}

diff  --git a/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
new file mode 100644
index 00000000000000..3674a91ef425f8
--- /dev/null
+++ b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="extra-types=f16 target-type=f32" | FileCheck %s
+
+// COM: f8E5M2 is neither implicitly supported nor listed in extra-types, so
+// COM: math.sin runs on a value extended to f32 and the result is truncated.
+// CHECK-LABEL: @sin_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+func.func @sin_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+  // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+  // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+  // CHECK: return [[TRUNCF]] : f8E5M2
+  %0 = math.sin %arg0 : f8E5M2
+  return %0 : f8E5M2
+}
+
+// COM: f16 is listed in extra-types, so math.sin on f16 must be left as is.
+// COM: The directives below previously used an undeclared prefix and were
+// COM: silently ignored by FileCheck.
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+  // CHECK-NOT: arith.extf
+  // CHECK: [[SIN:%.+]] = math.sin [[ARG0]] : f16
+  // CHECK-NOT: arith.truncf
+  // CHECK: return [[SIN]] : f16
+  %0 = math.sin %arg0 : f16
+  return %0 : f16
+}
+
+// COM: Only the floating-point operand of math.fpowi is extended; the i32
+// COM: exponent is passed through unchanged.
+// CHECK-LABEL: @fpowi_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2, [[ARG1:%.+]]: i32)
+func.func @fpowi_f8E5M2(%arg0: f8E5M2, %arg1: i32) -> f8E5M2 {
+  // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
+  // CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
+  // CHECK: return [[TRUNCF]] : f8E5M2
+  %0 = math.fpowi %arg0, %arg1 : f8E5M2, i32
+  return %0 : f8E5M2
+}
+
+// COM: f16 is listed in extra-types, so math.fpowi on f16 is left as is.
+// CHECK-LABEL: @fpowi
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
+func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
+  // CHECK: [[FPOWI:%.+]] = math.fpowi [[ARG0]], [[ARG1]]
+  // CHECK: return [[FPOWI]] : f16
+  %0 = math.fpowi %arg0, %arg1 : f16, i32
+  return %0 : f16
+}
+
+// COM: Verify that the pass leaves `math.fma` untouched, since it is often
+// COM: implemented on small data types.
+// CHECK-LABEL: @fma
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
+// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK: return [[FMA]] : f16
+func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+  %0 = math.fma %arg0, %arg1, %arg2 : f16
+  return %0 : f16
+}
+
+// COM: f16 (from extra-types) and the implicitly supported f32 and f64 are
+// COM: all left untouched.
+// CHECK-LABEL: @absf_f16
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f16
+func.func @absf_f16(%arg0: f16) -> f16 {
+  %0 = math.absf %arg0 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @absf_f32
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f32
+func.func @absf_f32(%arg0: f32) -> f32 {
+  %0 = math.absf %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @absf_f64
+// CHECK-SAME: ([[ARG0:%.+]]: f64)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f64
+func.func @absf_f64(%arg0: f64) -> f64 {
+  %0 = math.absf %arg0 : f64
+  return %0 : f64
+}
+
+// COM: Shaped types are handled by element type: vector<2xbf16> is extended
+// COM: because bf16 is not supported here, while vector<2xf16> is not.
+// CHECK-LABEL: @sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<2xbf16>
+func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
+  %0 = math.sin %arg0 : vector<2xbf16>
+  return %0 : vector<2xbf16>
+}
+
+// CHECK-LABEL: @sin_vector_f16
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xf16>)
+// CHECK: [[SIN:%.+]] = math.sin [[ARG0]]
+// CHECK: return [[SIN]] : vector<2xf16>
+func.func @sin_vector_f16(%arg0: vector<2xf16>) -> vector<2xf16> {
+  %0 = math.sin %arg0 : vector<2xf16>
+  return %0 : vector<2xf16>
+}
+
+// COM: Fastmath flags on an op that needs no extension are preserved.
+// CHECK-LABEL: @fastmath
+// CHECK: math.sin %{{.+}} fastmath<nsz>
+func.func @fastmath(%arg0: f16) -> f16 {
+  %0 = math.sin %arg0 fastmath<nsz> : f16
+  return %0 : f16
+}
+
+// COM: Each unsupported op gets its own extf/truncf pair; the truncf/extf
+// COM: between consecutive ops is not folded away by this pass.
+// CHECK-LABEL: @sequences_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF1]] : f8E5M2
+func.func @sequences_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+  %0 = math.absf %arg0 : f8E5M2
+  %1 = math.sin %0 : f8E5M2
+  return %1 : f8E5M2
+}
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: return [[SIN]] : f16
+func.func @sequences(%arg0: f16) -> f16 {
+  %0 = math.absf %arg0 : f16
+  %1 = math.sin %0 : f16
+  return %1 : f16
+}
+
+// COM: The function argument used inside the scf.if region is extended
+// COM: immediately before the scf.if, while the truncf stays inside the
+// COM: region next to the converted op.
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf
+  // CHECK-NEXT: %[[RES:.*]] = scf.if
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/extend-to-supported-types.mlir
similarity index 96%
rename from mlir/test/Dialect/Math/legalize-to-f32.mlir
rename to mlir/test/Dialect/Math/extend-to-supported-types.mlir
index ebb0de9d2653e2..ad7169d4cf4ae5 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/extend-to-supported-types.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="target-type=f32" | FileCheck %s
 
 // CHECK-LABEL: @sin
 // CHECK-SAME: ([[ARG0:%.+]]: f16)


        


More information about the Mlir-commits mailing list