[Mlir-commits] [mlir] [mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments (PR #108815)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 16 03:29:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir-arith

Author: Daniel Hernandez-Juarez (dhernandez0)

<details>
<summary>Changes</summary>

Instead of hard-coding that all floating-point types narrower than 32 bits are unsupported, this patch provides a way to pass the supported floating-point types as well as the target type. fp64 and fp32 are implicitly supported.

---

Patch is 24.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108815.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/Utils/Utils.h (+4) 
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+4-3) 
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.td (+10-2) 
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-24) 
- (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+19) 
- (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1-1) 
- (added) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+164) 
- (removed) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (-118) 
- (added) mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir (+137) 
- (renamed) mlir/test/Dialect/Math/extend-to-supported-types.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..81da2d208d7269 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -130,6 +130,10 @@ namespace arith {
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
                     Type resultType);
+
+/// Map a float type name such as "bf16" to its FloatType, or std::nullopt.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 2dd7f6431f03e1..b6ee5bc93ff694 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -56,10 +56,12 @@ void populateMathPolynomialApproximationPatterns(
 void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
 
 namespace math {
-void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
-void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
-                                           TypeConverter &typeConverter);
-void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
-                                   TypeConverter &typeConverter);
+void populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType);
+void populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter);
+void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
+                                            TypeConverter &typeConverter);
 } // namespace math
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda58..a84c89020d4f36 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
   let dependentDialects = ["math::MathDialect"];
 }
 
-def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
   let summary = "Legalize floating-point math ops on low-precision floats";
   let description = [{
     On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
 
     This pass explicitly legalizes these math functions by inserting
     `arith.extf` and `arith.truncf` pairs around said op, which preserves
-    the original semantics while enabling lowering.
+    the original semantics while enabling lowering. The extra floating-point
+    types supported by the target are passed as options; f64 and f32 are
+    always implicitly supported.
 
     As an exception, this pass does not legalize `math.fma`, because
     that is an operation frequently implemented at low precisions.
   }];
+  let options = [
+    ListOption<"extraTypeStrs", "extra-types", "std::string",
+      "MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
+    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+      "MLIR type to convert the unsupported source types to">,
+  ];
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..eb8f74270d2ef8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
@@ -49,28 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
 };
 } // end namespace
 
-/// Map strings to float types. This function is here because no one else needs
-/// it yet, feel free to abstract it out.
-static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
-                                               StringRef name) {
-  Builder b(ctx);
-  return llvm::StringSwitch<std::optional<FloatType>>(name)
-      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
-      .Case("f8E5M2", b.getFloat8E5M2Type())
-      .Case("f8E4M3", b.getFloat8E4M3Type())
-      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
-      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
-      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
-      .Case("f8E3M4", b.getFloat8E3M4Type())
-      .Case("bf16", b.getBF16Type())
-      .Case("f16", b.getF16Type())
-      .Case("f32", b.getF32Type())
-      .Case("f64", b.getF64Type())
-      .Case("f80", b.getF80Type())
-      .Case("f128", b.getF128Type())
-      .Default(std::nullopt);
-}
-
 LogicalResult EmulateFloatPattern::match(Operation *op) const {
   if (getTypeConverter()->isLegal(op))
     return failure();
@@ -154,7 +133,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   SmallVector<Type> sourceTypes;
   Type targetType;
 
-  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+  std::optional<FloatType> maybeTargetType = arith::parseFloatType(ctx, targetTypeStr);
   if (!maybeTargetType) {
     emitError(UnknownLoc::get(ctx), "could not map target type '" +
                                         targetTypeStr +
@@ -164,7 +143,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   targetType = *maybeTargetType;
   for (StringRef sourceTypeStr : sourceTypeStrs) {
     std::optional<FloatType> maybeSourceType =
-        parseFloatType(ctx, sourceTypeStr);
+        arith::parseFloatType(ctx, sourceTypeStr);
     if (!maybeSourceType) {
       emitError(UnknownLoc::get(ctx), "could not map source type '" +
                                           sourceTypeStr +
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..c4d11f5474a15e 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -357,4 +357,25 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
       [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
 }
 
+/// Map a float type name such as "bf16" to its FloatType, or std::nullopt.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
+  Builder b(ctx);
+
+  return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
+      .Case("f8E5M2", b.getFloat8E5M2Type())
+      .Case("f8E4M3", b.getFloat8E4M3Type())
+      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
+      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
+      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+      .Case("f8E3M4", b.getFloat8E3M4Type())
+      .Case("bf16", b.getBF16Type())
+      .Case("f16", b.getF16Type())
+      .Case("f32", b.getF32Type())
+      .Case("f64", b.getF64Type())
+      .Case("f80", b.getF80Type())
+      .Case("f128", b.getF128Type())
+      .Default(std::nullopt);
+}
+
 } // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb52712..e1c0c2410c1269 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
-  LegalizeToF32.cpp
+  ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
new file mode 100644
index 00000000000000..fff2a962333017
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -0,0 +1,164 @@
+//===- ExtendToSupportedTypes.cpp - Legalize unsupported float types ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on unsupported floating-point
+// types through arith.extf and arith.truncf. f64 and f32 are always supported;
+// further supported types and the target type are given as pass options.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
+  ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
+                                       MLIRContext *context)
+      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+// Pass wrapper: parses the type options and applies the pattern above.
+struct ExtendToSupportedTypesPass
+    : mlir::math::impl::MathExtendToSupportedTypesBase<
+          ExtendToSupportedTypesPass> {
+  using math::impl::MathExtendToSupportedTypesBase<
+      ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
+  // Inherits the tablegen-generated base-class constructors.
+  void runOnOperation() override;
+};
+} // namespace
+
+void mlir::math::populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType) {
+  // Floats outside `sourceTypes` become `targetType` (set captured by copy).
+  typeConverter.addConversion(
+      [](Type type) -> std::optional<Type> { return type; });
+  typeConverter.addConversion(
+      [sourceTypes, targetType](FloatType type) -> std::optional<Type> {
+        if (!sourceTypes.contains(type))
+          return targetType;
+        // Supported type: fall through to the identity conversion.
+        return std::nullopt;
+      });
+  typeConverter.addConversion(
+      [sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
+        if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
+          if (!sourceTypes.contains(elemTy))
+            return type.clone(targetType);
+        // Non-float or supported element type: fall through to identity.
+        return std::nullopt;
+      });
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
+      });
+}
+
+void mlir::math::populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  // Only math-dialect ops need legal types; all other ops are left untouched.
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    return !isa<MathDialect>(op->getDialect()) || typeConverter.isLegal(op);
+  });
+  // math.fma is exempt since it is often native at low precision; extf and
+  // truncf are the casts this pass itself inserts.
+  target.addLegalOp<FmaOp, arith::ExtFOp, arith::TruncFOp>();
+}
+
+LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  const TypeConverter *converter = getTypeConverter();
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
+  // Truncate each converted result back to the op's original result type.
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
+    if (newType != origType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+      truncFOp.setFastmath(arith::FastMathFlags::contract);
+      result = truncFOp.getResult();
+    }
+  }
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+void mlir::math::populateExtendToSupportedTypesPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter, ctx);
+}
+
+void ExtendToSupportedTypesPass::runOnOperation() {
+  Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
+  // Build the supported-type set from the options, then run the conversion.
+  // Parse the type that unsupported floats are converted to (`target-type`).
+  std::optional<Type> maybeTargetType =
+      arith::parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType.has_value()) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                        targetTypeStr +
+                                        "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  Type targetType = maybeTargetType.value();
+
+  // Parse the extra types the target supports natively (`extra-types`).
+  llvm::SetVector<Type> sourceTypes;
+  for (const auto &extraTypeStr : extraTypeStrs) {
+    std::optional<FloatType> maybeExtraType =
+        arith::parseFloatType(ctx, extraTypeStr);
+    if (!maybeExtraType.has_value()) {
+      emitError(UnknownLoc::get(ctx), "could not map source type '" +
+                                          extraTypeStr +
+                                          "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    sourceTypes.insert(maybeExtraType.value());
+  }
+  // f64 and f32 are always treated as supported, regardless of the options.
+  Builder b(ctx);
+  sourceTypes.insert(b.getF64Type());
+  sourceTypes.insert(b.getF32Type());
+  // Convert math ops whose float types are not in `sourceTypes`.
+  TypeConverter typeConverter;
+  math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
+                                                    targetType);
+  ConversionTarget target(*ctx);
+  math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
+  RewritePatternSet patterns(ctx);
+  math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
deleted file mode 100644
index 2e60fe455dcade..00000000000000
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements legalizing math operations on small floating-point
-// types through arith.extf and arith.truncf.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHLEGALIZETOF32
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-namespace {
-struct LegalizeToF32RewritePattern final : ConversionPattern {
-  LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
-      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-struct LegalizeToF32Pass final
-    : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
-  void runOnOperation() override;
-};
-} // namespace
-
-void mlir::math::populateLegalizeToF32TypeConverter(
-    TypeConverter &typeConverter) {
-  typeConverter.addConversion(
-      [](Type type) -> std::optional<Type> { return type; });
-  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
-    if (type.getWidth() < 32)
-      return Float32Type::get(type.getContext());
-    return std::nullopt;
-  });
-  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
-    if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
-      return type.clone(Float32Type::get(type.getContext()));
-    return std::nullopt;
-  });
-  typeConverter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
-}
-
-void mlir::math::populateLegalizeToF32ConversionTarget(
-    ConversionTarget &target, TypeConverter &typeConverter) {
-  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
-    if (isa<MathDialect>(op->getDialect()))
-      return typeConverter.isLegal(op);
-    return true;
-  });
-  target.addLegalOp<FmaOp>();
-  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
-}
-
-LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  const TypeConverter *converter = getTypeConverter();
-  FailureOr<Operation *> legalized =
-      convertOpResultTypes(op, operands, *converter, rewriter);
-  if (failed(legalized))
-    return failure();
-
-  SmallVector<Value> results = (*legalized)->getResults();
-  for (auto [result, newType, origType] : llvm::zip_equal(
-           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      result = truncFOp.getResult();
-    }
-  }
-  rewriter.replaceOp(op, results);
-  return success();
-}
-
-void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
-                                               TypeConverter &typeConverter) {
-  patterns.add<LegalizeToF32RewritePattern>(typeConverter,
-                                            patterns.getContext());
-}
-
-void LegalizeToF32Pass::runOnOperation() {
-  Operation *op = getOperation();
-  MLIRContext &ctx = getContext();
-
-  TypeConverter typeConverter;
-  math::populateLegalizeToF32TypeConverter(typeConverter);
-  ConversionTarget target(ctx);
-  math::populateLegalizeToF32ConversionTarget(target, typeConverter);
-  RewritePatternSet patterns(&ctx);
-  math::populateLegalizeToF32Patterns(patterns, typeConverter);
-  if (failed(applyPartialConversion(op, target, std::move(patterns))))
-    return signalPassFailure();
-}
diff --git a/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
new file mode 100644
index 00000000000000..6a61b3ba6251fa
--- /dev/null
+++ b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="extra-types=f16 target-type=f32" | FileCheck %s
+
+// CHECK-LABEL: @sin_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+func.func @sin_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+  // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+  // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+  // CHECK: return [[TRUNCF]] : f8E5M2
+  %0 = math.sin %arg0 : f8E5M2
+  return %0 : f8E5M2
+}
+
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+  // CHECK: [[SIN:%.+]] = math.sin [[ARG0]] : f16
+  // CHECK: return [[SIN]] : f16
+  %0 = math.sin %arg0 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @fpowi_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2, [[ARG1:%.+]]: i32)
+func.func @fpowi_f8...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/108815


More information about the Mlir-commits mailing list