[Mlir-commits] [mlir] 589764a - [mlir][math] Initial support for fastmath flag attributes for Math dialect.
Slava Zakharin
llvmlistbot at llvm.org
Fri Nov 4 10:44:27 PDT 2022
Author: Slava Zakharin
Date: 2022-11-04T10:41:56-07:00
New Revision: 589764a382642ae8374cfe21a6b10f839c8596da
URL: https://github.com/llvm/llvm-project/commit/589764a382642ae8374cfe21a6b10f839c8596da
DIFF: https://github.com/llvm/llvm-project/commit/589764a382642ae8374cfe21a6b10f839c8596da.diff
LOG: [mlir][math] Initial support for fastmath flag attributes for Math dialect.
Added arith::FastMathAttr and ArithFastMathInterface support for Math dialect
floating point operations.
This change-set creates ArithCommon conversion utils that currently
provide classes and methods to aid with arith::FastMathAttr conversion
into LLVM::FastmathFlags. These utils are used in the ArithToLLVM and
MathToLLVM converters, but may eventually be used by other converters
that need to convert fast math attributes.
Since Math dialect operations use arith::FastMathAttr, MathOps.td now
has to include the enum and attribute definitions from the Arith dialect.
To minimize the amount of TD code included from the Arith dialect,
I moved the FastMathAttr definition into ArithBase.td.
Differential Revision: https://reviews.llvm.org/D136312
Added:
mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
mlir/lib/Conversion/ArithCommon/CMakeLists.txt
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/Dialect/Math/IR/Math.h
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Dialect/Math/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
new file mode 100644
index 0000000000000..f27f7bb5975ec
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -0,0 +1,81 @@
+//===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
+#define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+//===----------------------------------------------------------------------===//
+// Support for converting Arith FastMathFlags to LLVM FastmathFlags
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace arith {
+// Map arithmetic fastmath enum values to LLVMIR enum values.
+LLVM::FastmathFlags
+convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
+
+// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+LLVM::FastmathFlagsAttr
+convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
+
+// Attribute converter that populates a NamedAttrList by removing the fastmath
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertFastMathToLLVM {
+public:
+ AttrConvertFastMathToLLVM(SourceOp srcOp) {
+ // Copy the source attributes.
+ convertedAttr = NamedAttrList{srcOp->getAttrs()};
+ // Get the name of the arith fastmath attribute.
+ llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+ // Remove the source fastmath attribute.
+ auto arithFMFAttr =
+ convertedAttr.erase(arithFMFAttrName)
+ .template dyn_cast_or_null<arith::FastMathFlagsAttr>();
+ if (arithFMFAttr) {
+ llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
+ convertedAttr.set(targetAttrName,
+ convertArithFastMathAttrToLLVM(arithFMFAttr));
+ }
+ }
+
+ ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+ NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the fastmath
+// attribute from the source operation attributes. This may be useful for
+// target operations that do not require the fastmath attribute, or for targets
+// that do not yet support the LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrDropFastMath {
+public:
+ AttrDropFastMath(SourceOp srcOp) {
+ // Copy the source attributes.
+ convertedAttr = NamedAttrList{srcOp->getAttrs()};
+ // Get the name of the arith fastmath attribute.
+ llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+ // Remove the source fastmath attribute.
+ convertedAttr.erase(arithFMFAttrName);
+ }
+
+ ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+ NamedAttrList convertedAttr;
+};
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 13d252cf056e5..78fd7bdf012f8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -121,4 +121,9 @@ def FastMathFlags : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+def Arith_FastMathAttr :
+ EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6ca74392f0565..3d6cef9705ebe 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -20,11 +20,6 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
-def Arith_FastMathAttr :
- EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
// Base class for Arith dialect ops. Ops in this dialect have no side
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<Trait> traits = []> :
diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h
index 6af358bf57b37..98416d1c9abdf 100644
--- a/mlir/include/mlir/Dialect/Math/IR/Math.h
+++ b/mlir/include/mlir/Dialect/Math/IR/Math.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_IR_MATH_H_
#define MLIR_DIALECT_MATH_IR_MATH_H_
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 99e209000c0f5..a5b28bd0891c5 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -9,6 +9,8 @@
#ifndef MATH_OPS
#define MATH_OPS
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Math/IR/MathBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/VectorInterfaces.td"
@@ -36,11 +38,16 @@ class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
// operand and result of the same type. This type can be a floating point type,
// vector or tensor thereof.
class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
- Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins FloatLike:$operand);
+ Math_Op<mnemonic,
+ traits # [SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let arguments = (ins FloatLike:$operand,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
- let assemblyFormat = "$operand attr-dict `:` type($result)";
+ let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
}
// Base class for binary math operations on integer types. Require two
@@ -58,22 +65,32 @@ class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
// operands and one result of the same type. This type can be a floating point
// type, vector or tensor thereof.
class Math_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
- Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+ Math_Op<mnemonic,
+ traits # [SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
+ let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
}
// Base class for floating point ternary operations. Require three operands and
// one result of the same type. This type can be a floating point type, vector
// or tensor thereof.
class Math_FloatTernaryOp<string mnemonic, list<Trait> traits = []> :
- Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c);
+ Math_Op<mnemonic,
+ traits # [SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
- let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
+ let assemblyFormat = [{ $a `,` $b `,` $c (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
new file mode 100644
index 0000000000000..8c5d76f9f2d72
--- /dev/null
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -0,0 +1,38 @@
+//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+
+using namespace mlir;
+
+// Map arithmetic fastmath enum values to LLVMIR enum values.
+LLVM::FastmathFlags
+mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
+ LLVM::FastmathFlags llvmFMF{};
+ const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
+ {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
+ {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
+ {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
+ {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
+ {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
+ {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
+ {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
+ for (auto fmfMap : flags) {
+ if (bitEnumContainsAny(arithFMF, fmfMap.first))
+ llvmFMF = llvmFMF | fmfMap.second;
+ }
+ return llvmFMF;
+}
+
+// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+LLVM::FastmathFlagsAttr
+mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
+ arith::FastMathFlags arithFMF = fmfAttr.getValue();
+ return LLVM::FastmathFlagsAttr::get(
+ fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
+}
diff --git a/mlir/lib/Conversion/ArithCommon/CMakeLists.txt b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt
new file mode 100644
index 0000000000000..888c45f2e52fe
--- /dev/null
+++ b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_conversion_library(MLIRArithAttrToLLVMConversion
+ AttrToLLVMConverter.cpp
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRLLVMDialect
+ )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f2814b56d4d34..1409b7fe1bca8 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -8,6 +8,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -24,93 +25,20 @@ using namespace mlir;
namespace {
-// Map arithmetic fastmath enum values to LLVMIR enum values.
-static LLVM::FastmathFlags
-convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
- LLVM::FastmathFlags llvmFMF{};
- const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
- {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
- {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
- {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
- {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
- {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
- {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
- {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
- for (auto fmfMap : flags) {
- if (bitEnumContainsAny(arithFMF, fmfMap.first))
- llvmFMF = llvmFMF | fmfMap.second;
- }
- return llvmFMF;
-}
-
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
-static LLVM::FastmathFlagsAttr
-convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
- arith::FastMathFlags arithFMF = fmfAttr.getValue();
- return LLVM::FastmathFlagsAttr::get(
- fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
-}
-
-// Attribute converter that populates a NamedAttrList by removing the fastmath
-// attribute from the source operation attributes, and replacing it with an
-// equivalent LLVM fastmath attribute.
-template <typename SourceOp, typename TargetOp>
-class AttrConvertFastMath {
-public:
- AttrConvertFastMath(SourceOp srcOp) {
- // Copy the source attributes.
- convertedAttr = NamedAttrList{srcOp->getAttrs()};
- // Get the name of the arith fastmath attribute.
- llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
- // Remove the source fastmath attribute.
- auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName)
- .template dyn_cast_or_null<arith::FastMathFlagsAttr>();
- if (arithFMFAttr) {
- llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
- convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr));
- }
- }
-
- ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
-
-private:
- NamedAttrList convertedAttr;
-};
-
-// Attribute converter that populates a NamedAttrList by removing the fastmath
-// attribute from the source operation attributes. This may be useful for
-// target operations that do not require the fastmath attribute, or for targets
-// that do not yet support the LLVM fastmath attribute.
-template <typename SourceOp, typename TargetOp>
-class AttrDropFastMath {
-public:
- AttrDropFastMath(SourceOp srcOp) {
- // Copy the source attributes.
- convertedAttr = NamedAttrList{srcOp->getAttrs()};
- // Get the name of the arith fastmath attribute.
- llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
- // Remove the source fastmath attribute.
- convertedAttr.erase(arithFMFAttrName);
- }
-
- ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
-
-private:
- NamedAttrList convertedAttr;
-};
-
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
-using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- AttrConvertFastMath>;
+using AddFOpLowering =
+ VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+ arith::AttrConvertFastMathToLLVM>;
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
-using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- AttrConvertFastMath>;
+using DivFOpLowering =
+ VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+ arith::AttrConvertFastMathToLLVM>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
@@ -125,28 +53,30 @@ using FPToSIOpLowering =
using FPToUIOpLowering =
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
// TODO: Add LLVM intrinsic support for fastmath
-using MaxFOpLowering =
- VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
+using MaxFOpLowering = VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
+ arith::AttrDropFastMath>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
// TODO: Add LLVM intrinsic support for fastmath
-using MinFOpLowering =
- VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
+using MinFOpLowering = VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
+ arith::AttrDropFastMath>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
-using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- AttrConvertFastMath>;
+using MulFOpLowering =
+ VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+ arith::AttrConvertFastMathToLLVM>;
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
-using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- AttrConvertFastMath>;
+using NegFOpLowering =
+ VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
+ arith::AttrConvertFastMathToLLVM>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
// TODO: Add LLVM intrinsic support for fastmath
-using RemFOpLowering =
- VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
+using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+ arith::AttrDropFastMath>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -160,8 +90,9 @@ using ShRUIOpLowering =
VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
-using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- AttrConvertFastMath>;
+using SubFOpLowering =
+ VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+ arith::AttrConvertFastMathToLLVM>;
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
using TruncFOpLowering =
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
diff --git a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
index 45ee8708aa155..bb1fa2fbb6577 100644
--- a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToLLVM
Core
LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
MLIRArithDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index a65814d36b5b4..62dae19a31344 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
+add_subdirectory(ArithCommon)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
index a6e6b4f56d37e..97393fc849691 100644
--- a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToLLVM
Core
LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRMathDialect
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index b67a86f443b5c..b5ce019b20832 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -8,6 +8,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -24,31 +25,39 @@ namespace mlir {
using namespace mlir;
namespace {
-using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
-using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
+
+template <typename SourceOp, typename TargetOp>
+using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
+
+template <typename SourceOp, typename TargetOp>
+using ConvertFMFMathToLLVMPattern =
+ VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+
+using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
+using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
- VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
-using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
+ ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
+using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CtPopFOpLowering =
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
-using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
-using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
+using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
+using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
- VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
+ ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
+using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
using Log10OpLowering =
- VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
-using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
-using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
-using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+ ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
+using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
+using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
+using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using RoundEvenOpLowering =
- VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
+ ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
using RoundOpLowering =
- VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
-using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
-using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+ ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
+using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
+using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
- VectorConvertToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
+ ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
template <typename MathOp, typename LLVMOp>
@@ -113,6 +122,8 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
auto resultType = op.getResult().getType();
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+ ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
+ ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
if (!operandType.isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp one;
@@ -123,8 +134,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
- auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
- rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+ auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
+ expAttrs.getAttrs());
+ rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
+ op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
}
@@ -142,9 +155,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto exp =
- rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
- return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+ auto exp = rewriter.create<LLVM::ExpOp>(
+ loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
+ return rewriter.create<LLVM::FSubOp>(
+ loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
},
rewriter);
}
@@ -166,6 +180,8 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
auto resultType = op.getResult().getType();
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+ ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
+ ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
if (!operandType.isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp one =
@@ -176,9 +192,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
floatOne))
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
- auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
- adaptor.getOperand());
- rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+ auto add = rewriter.create<LLVM::FAddOp>(
+ loc, operandType, ValueRange{one, adaptor.getOperand()},
+ addAttrs.getAttrs());
+ rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
+ logAttrs.getAttrs());
return success();
}
@@ -196,9 +214,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
- operands[0]);
- return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+ auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
+ ValueRange{one, operands[0]},
+ addAttrs.getAttrs());
+ return rewriter.create<LLVM::LogOp>(
+ loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
},
rewriter);
}
@@ -220,6 +240,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
auto resultType = op.getResult().getType();
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+ ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
+ ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
if (!operandType.isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp one;
@@ -230,8 +252,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+ auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
+ sqrtAttrs.getAttrs());
+ rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
+ op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
}
@@ -249,9 +273,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto sqrt =
- rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
- return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
+ auto sqrt = rewriter.create<LLVM::SqrtOp>(
+ loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
+ return rewriter.create<LLVM::FDivOp>(
+ loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
},
rewriter);
}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index bcdbad1709e93..8c7f031cb97d9 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -36,6 +36,18 @@ func.func @log1p(%arg0 : f32) {
// -----
+// CHECK-LABEL: func @log1p_fmf(
+// CHECK-SAME: f32
+func.func @log1p_fmf(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ %0 = math.log1p %arg0 fastmath<fast> : f32
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @log1p_2dvector(
func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -49,6 +61,19 @@ func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @log1p_2dvector_fmf(
+func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<3xf32>
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<3xf32>) -> vector<3xf32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ %0 = math.log1p %arg0 fastmath<fast> : vector<4x3xf32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @expm1(
// CHECK-SAME: f32
func.func @expm1(%arg0 : f32) {
@@ -61,6 +86,42 @@ func.func @expm1(%arg0 : f32) {
// -----
+// CHECK-LABEL: func @expm1_fmf(
+// CHECK-SAME: f32
+func.func @expm1_fmf(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath<fast>} : f32
+ %0 = math.expm1 %arg0 fastmath<fast> : f32
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @expm1_vector(
+// CHECK-SAME: vector<4xf32>
+func.func @expm1_vector(%arg0 : vector<4xf32>) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) : (vector<4xf32>) -> vector<4xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<4xf32>
+ %0 = math.expm1 %arg0 : vector<4xf32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @expm1_vector_fmf(
+// CHECK-SAME: vector<4xf32>
+func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf32>) -> vector<4xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath<fast>} : vector<4xf32>
+ %0 = math.expm1 %arg0 fastmath<fast> : vector<4xf32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt(
// CHECK-SAME: f32
func.func @rsqrt(%arg0 : f32) {
@@ -148,6 +209,18 @@ func.func @rsqrt_double(%arg0 : f64) {
// -----
+// CHECK-LABEL: func @rsqrt_double_fmf(
+// CHECK-SAME: f64
+func.func @rsqrt_double_fmf(%arg0 : f64) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : f64
+ %0 = math.rsqrt %arg0 fastmath<fast> : f64
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_vector(
// CHECK-SAME: vector<4xf32>
func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
@@ -160,6 +233,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_vector_fmf(
+// CHECK-SAME: vector<4xf32>
+func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf32>) -> vector<4xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<4xf32>
+ %0 = math.rsqrt %arg0 fastmath<fast> : vector<4xf32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_multidim_vector(
func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -210,3 +295,19 @@ func.func @trunc(%arg0 : f32) {
%0 = math.trunc %arg0 : f32
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @fastmath(
+// CHECK-SAME: f32
+func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
+ // CHECK: llvm.intr.trunc(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ %0 = math.trunc %arg0 fastmath<fast> : f32
+ // CHECK: llvm.intr.pow(%arg0, %arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+ %1 = math.powf %arg0, %arg0 fastmath<afn> : f32
+ // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32
+ %2 = math.sqrt %arg0 fastmath<none> : f32
+ // CHECK: llvm.intr.fma(%arg0, %arg0, %arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
+ %3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
+ func.return
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index d984cbb66f8c2..7e121f80dd79e 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -269,3 +269,17 @@ func.func @trunc(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
%2 = math.trunc %t : tensor<4x4x?xf32>
return
}
+
+// CHECK-LABEL: func @fastmath(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @fastmath(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.trunc %[[F]] fastmath<fast> : f32
+ %0 = math.trunc %f fastmath<fast> : f32
+ // CHECK: %{{.*}} = math.powf %[[V]], %[[V]] fastmath<fast> : vector<4xf32>
+ %1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : vector<4xf32>
+ // CHECK: %{{.*}} = math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32>
+ %2 = math.fma %t, %t, %t fastmath<none> : tensor<4x4x?xf32>
+ // CHECK: %{{.*}} = math.absf %[[F]] fastmath<ninf> : f32
+ %3 = math.absf %f fastmath<ninf> : f32
+ return
+}
More information about the Mlir-commits
mailing list