[Mlir-commits] [mlir] [MLIR][Arith] Add denormal attribute to binary/unary operations (PR #112700)
lorenzo chelini
llvmlistbot at llvm.org
Fri Nov 22 06:58:11 PST 2024
https://github.com/chelini updated https://github.com/llvm/llvm-project/pull/112700
>From a85ce4090ec3c5c1a1b36c56c0726230f8b4596e Mon Sep 17 00:00:00 2001
From: lorenzo chelini <lchelini at nvidia.com>
Date: Wed, 20 Nov 2024 20:21:01 +0100
Subject: [PATCH 1/4] [MLIR][Arith] Add denormal attribute to binary/unary
operations
Add support for denormal modes in the Arith dialect (binary and unary
floating-point operations). A denormal mode is attached to every such
operation and can be one of three kinds:
1) ieee (the default): denormals are preserved and processed as defined by
IEEE 754 rules.
2) preserve_sign: denormal numbers are flushed to zero, but the sign of the
zero (+0 or -0) is preserved.
3) positive_zero: all denormal numbers are flushed to positive zero (+0),
ignoring the sign of the original number.
The mode applies to both the operands and the result.
---
.../ArithCommon/AttrToLLVMConverter.h | 6 +-
.../mlir/Dialect/Arith/IR/ArithBase.td | 34 +++++++++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 22 +++---
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 40 ++++++++++-
mlir/include/mlir/IR/Matchers.h | 6 ++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 67 ++++++++++++------
.../Dialect/Arith/IR/ArithCanonicalization.td | 14 ++--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 24 ++++++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 +++--
mlir/test/CAPI/ir.c | 2 +-
mlir/test/Dialect/Arith/canonicalize.mlir | 23 ++++++
mlir/test/Dialect/Arith/invalid.mlir | 8 +++
mlir/test/Dialect/Arith/ops.mlir | 70 +++++++++++++++++++
mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
14 files changed, 286 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 7ffc8613317603..da067410db5eff 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
template <typename SourceOp, typename TargetOp>
class AttrConvertFastMathToLLVM {
public:
- AttrConvertFastMathToLLVM(SourceOp srcOp) {
+ explicit AttrConvertFastMathToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
@@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM {
template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
- AttrConvertOverflowToLLVM(SourceOp srcOp) {
+ explicit AttrConvertOverflowToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
@@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM {
"LLVM::FPExceptionBehaviorOpInterface");
public:
- AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
+ explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0e..4309c0618667a8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -181,4 +181,38 @@ def Arith_RoundingModeAttr : I32EnumAttr<
let cppNamespace = "::mlir::arith";
}
+//===----------------------------------------------------------------------===//
+// Arith_DenormalMode
+//===----------------------------------------------------------------------===//
+
+// The denormal mode applies to both operands and results. For example, with
+// denormal = preserve_sign, denormal operands and results are flushed to a zero
+// of the same sign. We do not distinguish between operands and results.
+
+// The default mode. Denormals are preserved and processed as defined
+// by IEEE 754 rules.
+def Arith_DenormalModeIEEE : I32BitEnumAttrCaseNone<"ieee">;
+
+// A mode where denormal numbers are flushed to zero, but the sign of the zero
+// (+0 or -0) is preserved.
+def Arith_DenormalModePreserveSign : I32BitEnumAttrCase<"preserve_sign", 1>;
+
+// A mode where all denormal numbers are flushed to positive zero (+0),
+// ignoring the sign of the original number.
+def Arith_DenormalModePositiveZero : I32BitEnumAttrCase<"positive_zero", 2>;
+
+def Arith_DenormalMode : I32BitEnumAttr<
+ "DenormalMode", "denormal mode arith",
+ [Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
+ Arith_DenormalModePositiveZero]> {
+ let cppNamespace = "::mlir::arith";
+ let genSpecializedAttr = 0;
+}
+
+def Arith_DenormalModeAttr :
+ EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
+ let assemblyFormat = "`<` $value `>`";
+ let genVerifyDecl = 1;
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 19a5e13a5d755d..4069e43af82e8e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -61,26 +61,35 @@ class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
// Base class for floating point unary operations.
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
Arith_UnaryOp<mnemonic,
- !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+ !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
traits)>,
Arguments<(ins FloatLike:$operand,
DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
+ Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+ DefaultValuedAttr<
+ Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
Results<(outs FloatLike:$result)> {
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+ (`denormal` `` $denormal^)?
attr-dict `:` type($result) }];
}
// Base class for floating point binary operations.
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic,
- !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
+ !listconcat([Pure,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
traits)>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
+ Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+ DefaultValuedAttr<
+ Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
Results<(outs FloatLike:$result)> {
- let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
+ let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
+ (`denormal` `` $denormal^)?
attr-dict `:` type($result) }];
}
@@ -1085,7 +1094,6 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
let hasFolder = 1;
}
-
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
@@ -1111,8 +1119,6 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
%x = arith.mulf %y, %z : tensor<4x?xbf16>
```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..270d80f2ec73af 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -45,13 +45,12 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
return "fastmath";
}]
>
-
];
}
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
let description = [{
- Access to op integer overflow flags.
+ Access to operation integer overflow flags.
}];
let cppNamespace = "::mlir::arith";
@@ -108,7 +107,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
let description = [{
- Access to op rounding mode.
+ Access to operation rounding mode.
}];
let cppNamespace = "::mlir::arith";
@@ -139,4 +138,39 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
];
}
+
+def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> {
+ let description = [{
+ Access to the operation denormal mode.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a DenormalModeAttr attribute for the operation",
+ /*returnType=*/ "DenormalModeAttr",
+ /*methodName=*/ "getDenormalModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getDenormalAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the DenormalModeAttr attribute for
+ the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getDenormalModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "denormal";
+ }]
+ >
+ ];
+}
+
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..226afb9ad25f1a 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -433,6 +433,12 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() {
}};
}
+/// Matches a constant scalar / vector splat / tensor splat with denormal
+/// values.
+inline detail::constant_float_predicate_matcher m_isDenormalFloat() {
+ return {[](const APFloat &value) { return value.isDenormal(); }};
+}
+
/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_predicate_matcher m_Zero() {
return {[](const APInt &value) { return 0 == value; }};
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index aac24f113d891f..54d941ae9f6c89 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -53,13 +53,40 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};
+template <typename SourceOp, typename TargetOp,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
+struct DenormalOpConversionToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                   AttrConvert>::VectorConvertToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: LLVM exposes denormal behavior as a function-level attribute, so
+    // non-ieee modes need a legalization step that moves the mode from the
+    // operations to their enclosing function and checks that all operations
+    // in a function agree. Until then, only the ieee mode is lowered.
+    if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee)
+      return rewriter.notifyMatchFailure(
+          op, "only ieee denormal mode is supported at the moment");
+    // Mutate `op` only through the rewriter, as dialect conversion requires.
+    StringRef attrName = SourceOp::getDenormalModeAttrName();
+    rewriter.modifyOpInPlace(op, [&] { op->removeAttr(attrName); });
+    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                      AttrConvert>::matchAndRewrite(op, adaptor,
+                                                                    rewriter);
+  }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
using AddFOpLowering =
- VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+ arith::AttrConvertFastMathToLLVM>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -67,8 +94,8 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
- VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+ arith::AttrConvertFastMathToLLVM>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
@@ -83,38 +110,38 @@ using FPToSIOpLowering =
using FPToUIOpLowering =
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
using MaximumFOpLowering =
- VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
+ arith::AttrConvertFastMathToLLVM>;
using MaxNumFOpLowering =
- VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
+ arith::AttrConvertFastMathToLLVM>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
- VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
+ arith::AttrConvertFastMathToLLVM>;
using MinNumFOpLowering =
- VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
+ arith::AttrConvertFastMathToLLVM>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
- VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+ arith::AttrConvertFastMathToLLVM>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
- VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
+ arith::AttrConvertFastMathToLLVM>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
- VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+ arith::AttrConvertFastMathToLLVM>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -131,8 +158,8 @@ using ShRUIOpLowering =
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
- VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM>;
+ DenormalOpConversionToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+ arith::AttrConvertFastMathToLLVM>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd7..22c34b2bd42f58 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -422,10 +422,11 @@ def TruncIShrUIMulIToMulUIExtended :
//===----------------------------------------------------------------------===//
// mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and denormal mode of the original mulf)
def MulFOfNegF :
- Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_MulFOp $x, $y, $fmf),
+ Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_),
+ (Arith_NegFOp $y, $_, $_), $fmf, $mode),
+ (Arith_MulFOp $x, $y, $fmf, $mode),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
//===----------------------------------------------------------------------===//
@@ -433,10 +434,11 @@ def MulFOfNegF :
//===----------------------------------------------------------------------===//
// divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and denormal mode of the original divf)
def DivFOfNegF :
- Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_DivFOp $x, $y, $fmf),
+ Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_),
+ (Arith_NegFOp $y, $_, $_), $fmf, $mode),
+ (Arith_DivFOp $x, $y, $fmf, $mode),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
#endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..1b8a459c6e8c4b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -952,7 +952,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
- /// negf(negf(x)) -> x
+ // negf(negf(x)) -> x
if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
return op.getOperand();
return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
@@ -982,6 +982,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();
+ // Simplifies subf(x, rhs) to x when `rhs` is a denormal constant and the
+ // denormal mode is positive_zero, since `rhs` is then flushed to +0.
+ // NOTE(review): not exact if `x` is itself denormal (it should become +0).
+ bool isPositiveZeroMode =
+ getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
+ if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
+ return getLhs();
+
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a - b; });
@@ -2635,6 +2643,20 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// DenormalModeAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult DenormalModeAttr::verify(
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ DenormalMode mode) {
+ auto value = static_cast<uint32_t>(mode);
+ bool isSingleBitSet = (value & (value - 1)) == 0;
+ if (!isSingleBitSet)
+ return emitError() << "expected only a single denormal mode";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 26d9d2b091750c..ce614208bef5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1501,12 +1501,17 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
for (const auto &attr : payloadOp->getAttrs()) {
- auto fastAttr =
- llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
- if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
- attrToElide = attr.getName().str();
- elidedAttrs.push_back(attrToElide);
- break;
+ if (auto fastAttr =
+ llvm::dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
+ if (fastAttr.getValue() == arith::FastMathFlags::none) {
+ elidedAttrs.push_back(attr.getName().str());
+ }
+ }
+ if (auto denormAttr =
+ llvm::dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
+ if (denormAttr.getValue() == arith::DenormalMode::ieee) {
+ elidedAttrs.push_back(attr.getName().str());
+ }
}
}
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 15a3a1fb50dc9e..fa3b0d894c995b 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -320,7 +320,7 @@ int collectStats(MlirOperation operation) {
// clang-format off
// CHECK-LABEL: @stats
// CHECK: Number of operations: 12
- // CHECK: Number of attributes: 5
+ // CHECK: Number of attributes: 6
// CHECK: Number of blocks: 3
// CHECK: Number of regions: 3
// CHECK: Number of values: 9
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..f56bf0980b13c1 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3189,3 +3189,26 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
}
}
#-}
+
+// -----
+
+// CHECK-LABEL: @test_fold_denorm
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_fold_denorm(%arg0: f32) -> f32 {
+ // CHECK-NOT: arith.subf
+ // CHECK: return %[[ARG0]] : f32
+ %c_denorm = arith.constant 1.4e-45 : f32
+ %sub = arith.subf %arg0, %c_denorm denormal<positive_zero> : f32
+ return %sub : f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_expect_not_to_fold_denorm
+func.func @test_expect_not_to_fold_denorm(%arg0: f32, %arg1 : f32) -> (f32, f32) {
+ // CHECK-COUNT-2: arith.subf
+ %c_denorm = arith.constant 1.4e-45 : f32
+ %sub = arith.subf %arg0, %c_denorm denormal<ieee> : f32
+ %sub_1 = arith.subf %arg1, %c_denorm denormal<preserve_sign> : f32
+ return %sub, %sub_1 : f32, f32
+}
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 088da475e8eb4c..4999008e572fc9 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -853,3 +853,11 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}
+
+// -----
+
+func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 {
+ // expected-error @below{{expected only a single denormal mode}}
+ %0 = arith.subf %arg0, %arg1 denormal<preserve_sign | positive_zero> : f32
+ return %0 : f32
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..c019974020879f 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1161,3 +1161,73 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
+
+// CHECK-LABEL: check_denorm_modes
+func.func @check_denorm_modes(%arg0: f32, %arg1: f32, %arg2: f32) {
+ %c_denorm = arith.constant 1.4e-45 : f32
+ // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %sub_preserve_sign = arith.subf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %sub_positive_zero = arith.subf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} : f32
+ %sub_ieee = arith.subf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %add_preserve_sign = arith.addf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %add_positive_zero = arith.addf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} : f32
+ %add_ieee = arith.addf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %mul_preserve_sign = arith.mulf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %mul_positive_zero = arith.mulf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} : f32
+ %mul_ieee = arith.mulf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %div_preserve_sign = arith.divf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %div_positive_zero = arith.divf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} : f32
+ %div_ieee = arith.divf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %maximumf_preserve_sign = arith.maximumf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %maximumf_positive_zero = arith.maximumf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} : f32
+ %maximumf_ieee = arith.maximumf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %maxnumf_preserve_sign = arith.maxnumf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %maxnumf_positive_zero = arith.maxnumf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} : f32
+ %maxnumf_ieee = arith.maxnumf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %minimumf_preserve_sign = arith.minimumf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %minimumf_positive_zero = arith.minimumf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} : f32
+ %minimumf_ieee = arith.minimumf %arg2, %c_denorm denormal<ieee> : f32
+
+ // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+ %minnumf_preserve_sign = arith.minnumf %arg0, %c_denorm denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+ %minnumf_positive_zero = arith.minnumf %arg1, %c_denorm denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} : f32
+ %minnumf_ieee = arith.minnumf %arg2, %c_denorm denormal<ieee> : f32
+
+
+ // CHECK: %{{.+}} = arith.negf %{{.+}} denormal<preserve_sign> : f32
+ %negf_preserve_sign = arith.negf %arg0 denormal<preserve_sign> : f32
+ // CHECK: %{{.+}} = arith.negf %{{.+}} denormal<positive_zero> : f32
+ %negf_positive_sign = arith.negf %arg0 denormal<positive_zero> : f32
+ // CHECK: %{{.+}} = arith.negf %{{.+}} : f32
+ %negf_ieee = arith.negf %arg0 denormal<ieee> : f32
+
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..e3b6958cfa8816 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
// -----
func.func @generic(%arg0: memref<?x?xf32>) {
- // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
+ // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{denormal = #arith.denormal<ieee>, fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
linalg.generic {
indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
iterator_types = ["parallel", "parallel"]}
>From e44b3e889769aede3aae42e7b0c10069d5ee7897 Mon Sep 17 00:00:00 2001
From: lorenzo chelini <lchelini at nvidia.com>
Date: Thu, 21 Nov 2024 20:40:03 +0100
Subject: [PATCH 2/4] Add missing newline at end of invalid.mlir
---
mlir/test/Dialect/Arith/invalid.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 4999008e572fc9..ca86d51fd3523d 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -860,4 +860,4 @@ func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 {
// expected-error @below{{expected only a single denormal mode}}
%0 = arith.subf %arg0, %arg1 denormal<preserve_sign | positive_zero> : f32
return %0 : f32
-}
\ No newline at end of file
+}
>From a50d31cb16619d523999a9a423b335257ab1a4d8 Mon Sep 17 00:00:00 2001
From: lorenzo chelini <lchelini at nvidia.com>
Date: Thu, 21 Nov 2024 21:37:18 +0100
Subject: [PATCH 3/4] Use a plain (non-bit) enum for DenormalMode
---
mlir/include/mlir/Dialect/Arith/IR/ArithBase.td | 9 ++++-----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 14 --------------
mlir/test/Dialect/Arith/invalid.mlir | 8 --------
3 files changed, 4 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 4309c0618667a8..d27ea5edcc8c8d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -191,17 +191,17 @@ def Arith_RoundingModeAttr : I32EnumAttr<
// The default mode. Denormals are preserved and processed as defined
// by IEEE 754 rules.
-def Arith_DenormalModeIEEE : I32BitEnumAttrCaseNone<"ieee">;
+def Arith_DenormalModeIEEE : I32EnumAttrCase<"ieee", 0>;
// A mode where denormal numbers are flushed to zero, but the sign of the zero
// (+0 or -0) is preserved.
-def Arith_DenormalModePreserveSign : I32BitEnumAttrCase<"preserve_sign", 1>;
+def Arith_DenormalModePreserveSign : I32EnumAttrCase<"preserve_sign", 1>;
// A mode where all denormal numbers are flushed to positive zero (+0),
// ignoring the sign of the original number.
-def Arith_DenormalModePositiveZero : I32BitEnumAttrCase<"positive_zero", 2>;
+def Arith_DenormalModePositiveZero : I32EnumAttrCase<"positive_zero", 2>;
-def Arith_DenormalMode : I32BitEnumAttr<
+def Arith_DenormalMode : I32EnumAttr<
"DenormalMode", "denormal mode arith",
[Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
Arith_DenormalModePositiveZero]> {
@@ -212,7 +212,6 @@ def Arith_DenormalMode : I32BitEnumAttr<
def Arith_DenormalModeAttr :
EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
let assemblyFormat = "`<` $value `>`";
- let genVerifyDecl = 1;
}
#endif // ARITH_BASE
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1b8a459c6e8c4b..47766f36ad05cf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2643,20 +2643,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return nullptr;
}
-//===----------------------------------------------------------------------===//
-// DenormalModeAttr
-//===----------------------------------------------------------------------===//
-
-LogicalResult DenormalModeAttr::verify(
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
- DenormalMode mode) {
- auto value = static_cast<uint32_t>(mode);
- bool isSingleBitSet = (value & (value - 1)) == 0;
- if (!isSingleBitSet)
- return emitError() << "expected only a single denormal mode";
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ca86d51fd3523d..088da475e8eb4c 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -853,11 +853,3 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}
-
-// -----
-
-func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 {
- // expected-error @below{{expected only a single denormal mode}}
- %0 = arith.subf %arg0, %arg1 denormal<preserve_sign | positive_zero> : f32
- return %0 : f32
-}
>From 9fda8198fea96cf160eaddd32a5cdd3d40023f5e Mon Sep 17 00:00:00 2001
From: lorenzo chelini <lchelini at nvidia.com>
Date: Fri, 22 Nov 2024 15:56:40 +0100
Subject: [PATCH 4/4] Drop redundant llvm:: namespace qualifiers
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ce614208bef5cc..98810c5f19d798 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1501,14 +1501,13 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
-  std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
   for (const auto &attr : payloadOp->getAttrs()) {
     if (auto fastAttr =
-            llvm::dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
+            dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
       if (fastAttr.getValue() == arith::FastMathFlags::none) {
-        elidedAttrs.push_back(attr.getName().str());
+        elidedAttrs.push_back(attr.getName().getValue());
       }
     }
     if (auto denormAttr =
-            llvm::dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
+            dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
       if (denormAttr.getValue() == arith::DenormalMode::ieee) {
-        elidedAttrs.push_back(attr.getName().str());
+        elidedAttrs.push_back(attr.getName().getValue());
       }
     }
More information about the Mlir-commits
mailing list