[Mlir-commits] [mlir] 4a7b56e - [MLIR][Arith] Add denormal attribute to binary/unary operations (#112700)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 02:58:47 PST 2024


Author: lorenzo chelini
Date: 2024-11-26T11:58:43+01:00
New Revision: 4a7b56e6e7dd0f83c379ad06b6e81450bc691ba6

URL: https://github.com/llvm/llvm-project/commit/4a7b56e6e7dd0f83c379ad06b6e81450bc691ba6
DIFF: https://github.com/llvm/llvm-project/commit/4a7b56e6e7dd0f83c379ad06b6e81450bc691ba6.diff

LOG: [MLIR][Arith] Add denormal attribute to binary/unary operations (#112700)

Add support for denormal modes in the Arith dialect (binary and unary
operations). A denormal mode is attached to every operation and can be
one of three kinds:

1) ieee: denormals are preserved and processed as defined by IEEE 754
rules.

2) preserve_sign: denormal numbers are flushed to zero, but the sign of
the zero (+0 or -0) is preserved.

3) positive_zero: all denormal numbers are flushed to positive zero
(+0), ignoring the sign of the original number.

The denormal mode applies to both the operands and the result. Currently,
only the ieee mode can be lowered.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
    mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
    mlir/include/mlir/IR/Matchers.h
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/CAPI/ir.c
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 7ffc8613317603..da067410db5eff 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
 template <typename SourceOp, typename TargetOp>
 class AttrConvertFastMathToLLVM {
 public:
-  AttrConvertFastMathToLLVM(SourceOp srcOp) {
+  explicit AttrConvertFastMathToLLVM(SourceOp srcOp) {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
     // Get the name of the arith fastmath attribute.
@@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM {
 template <typename SourceOp, typename TargetOp>
 class AttrConvertOverflowToLLVM {
 public:
-  AttrConvertOverflowToLLVM(SourceOp srcOp) {
+  explicit AttrConvertOverflowToLLVM(SourceOp srcOp) {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
     // Get the name of the arith overflow attribute.
@@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM {
                 "LLVM::FPExceptionBehaviorOpInterface");
 
 public:
-  AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
+  explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
 

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0e..d27ea5edcc8c8d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -181,4 +181,37 @@ def Arith_RoundingModeAttr : I32EnumAttr<
   let cppNamespace = "::mlir::arith";
 }
 
+//===----------------------------------------------------------------------===//
+// Arith_DenormalMode
+//===----------------------------------------------------------------------===//
+
+// Denormal mode is applied on operands and results. For example, if denormal =
+// preserve_sign, operands and results are flushed to a sign-preserving zero.
+// We do not distinguish between operands and results.
+
+// The default mode. Denormals are preserved and processed as defined
+// by IEEE 754 rules.
+def Arith_DenormalModeIEEE          : I32EnumAttrCase<"ieee", 0>;
+
+// A mode where denormal numbers are flushed to zero, but the sign of the zero
+// (+0 or -0) is preserved.
+def Arith_DenormalModePreserveSign  : I32EnumAttrCase<"preserve_sign", 1>;
+
+// A mode where all denormal numbers are flushed to positive zero (+0),
+// ignoring the sign of the original number.
+def Arith_DenormalModePositiveZero  : I32EnumAttrCase<"positive_zero", 2>;
+
+def Arith_DenormalMode : I32EnumAttr<
+    "DenormalMode", "denormal mode arith",
+    [Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
+     Arith_DenormalModePositiveZero]> {
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+}
+
+def Arith_DenormalModeAttr :
+    EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 19a5e13a5d755d..4069e43af82e8e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -61,26 +61,35 @@ class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
 // Base class for floating point unary operations.
 class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_UnaryOp<mnemonic,
-      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                   DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
                   traits)>,
     Arguments<(ins FloatLike:$operand,
       DefaultValuedAttr<
-        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+      DefaultValuedAttr<
+        Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
     Results<(outs FloatLike:$result)> {
   let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+                          (`denormal` `` $denormal^)?
                           attr-dict `:` type($result) }];
 }
 
 // Base class for floating point binary operations.
 class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic,
-      !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
+      !listconcat([Pure,
+                   DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                   DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
                   traits)>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
       DefaultValuedAttr<
-        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+      DefaultValuedAttr<
+        Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
     Results<(outs FloatLike:$result)> {
   let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
+                          (`denormal` `` $denormal^)?
                           attr-dict `:` type($result) }];
 }
 
@@ -1085,7 +1094,6 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
   let hasFolder = 1;
 }
 
-
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
@@ -1111,8 +1119,6 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
     %x = arith.mulf %y, %z : tensor<4x?xbf16>
     ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..270d80f2ec73af 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -45,13 +45,12 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
         return "fastmath";
       }]
       >
-
   ];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
   let description = [{
-    Access to op integer overflow flags.
+    Access to operation integer overflow flags.
   }];
 
   let cppNamespace = "::mlir::arith";
@@ -108,7 +107,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
 
 def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
   let description = [{
-    Access to op rounding mode.
+    Access to operation rounding mode.
   }];
 
   let cppNamespace = "::mlir::arith";
@@ -139,4 +138,39 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
   ];
 }
 
+
+def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> {
+  let description = [{
+    Access the operation denormal modes.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a DenormalModeAttr attribute for the operation",
+      /*returnType=*/  "DenormalModeAttr",
+      /*methodName=*/  "getDenormalModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getDenormalAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the DenormalModeAttr attribute for
+                         the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getDenormalModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "denormal";
+      }]
+    >
+  ];
+}
+
+
 #endif // ARITH_OPS_INTERFACES

diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..226afb9ad25f1a 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -433,6 +433,12 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() {
   }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat with denormal
+/// values.
+inline detail::constant_float_predicate_matcher m_isDenormalFloat() {
+  return {[](const APFloat &value) { return value.isDenormal(); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer zero.
 inline detail::constant_int_predicate_matcher m_Zero() {
   return {[](const APInt &value) { return 0 == value; }};

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index aac24f113d891f..54d941ae9f6c89 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -53,13 +53,40 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// Lowers an arith float op; only the default `ieee` denormal mode is handled.
+template <typename SourceOp, typename TargetOp,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
+struct DenormalOpConversionToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+  using Base = VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert>;
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: Here, we need a legalization step. LLVM provides a function-level
+    // attribute for denormal; here, we need to move this information from the
+    // operation to the function, making sure all the operations in the same
+    // function are consistent.
+    if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee)
+      return rewriter.notifyMatchFailure(
+          op, "only ieee denormal mode is supported at the moment");
+
+    // No per-op LLVM equivalent: strip it through the rewriter (rollback-safe).
+    StringRef attrName = SourceOp::getDenormalModeAttrName();
+    rewriter.modifyOpInPlace(op, [&] { op->removeAttr(attrName); });
+    return Base::matchAndRewrite(op, adaptor, rewriter);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
 
 using AddFOpLowering =
-    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -67,8 +94,8 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
-    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -83,38 +110,38 @@ using FPToSIOpLowering =
 using FPToUIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
 using MaximumFOpLowering =
-    VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using MaxNumFOpLowering =
-    VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
 using MinimumFOpLowering =
-    VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using MinNumFOpLowering =
-    VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
-    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                                arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
-    VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
-    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -131,8 +158,8 @@ using ShRUIOpLowering =
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
-    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                               arith::AttrConvertFastMathToLLVM>;
+    DenormalOpConversionToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                                      arith::AttrConvertFastMathToLLVM>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd7..22c34b2bd42f58 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -422,10 +422,11 @@ def TruncIShrUIMulIToMulUIExtended :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and denormal mode of the original mulf)
 def MulFOfNegF :
-    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_MulFOp $x, $y, $fmf),
+    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_),
+                      (Arith_NegFOp $y, $_, $_), $fmf, $mode),
+        (Arith_MulFOp $x, $y, $fmf, $mode),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 //===----------------------------------------------------------------------===//
@@ -433,10 +434,11 @@ def MulFOfNegF :
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and denormal mode of the original divf)
 def DivFOfNegF :
-    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_DivFOp $x, $y, $fmf),
+    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_),
+                      (Arith_NegFOp $y, $_, $_), $fmf, $mode),
+        (Arith_DivFOp $x, $y, $fmf, $mode),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 #endif // ARITH_PATTERNS

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..47766f36ad05cf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -952,7 +952,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
-  /// negf(negf(x)) -> x
+  // negf(negf(x)) -> x
   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
     return op.getOperand();
   return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
@@ -982,6 +982,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
     return getLhs();
 
+  // Simplifies subf(x, rhs) to x if the following conditions are met:
+  // 1. `rhs` is a denormal floating-point value.
+  // 2. The denormal mode for the operation is set to positive zero.
+  bool isPositiveZeroMode =
+      getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
+  if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
+    return getLhs();
+
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) { return a - b; });

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..d9840e3923c4f7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1498,15 +1498,17 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {
 
 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
   SmallVector<StringRef> elidedAttrs;
-  std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
   for (const auto &attr : payloadOp->getAttrs()) {
-    auto fastAttr =
-        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
-    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
-      attrToElide = attr.getName().str();
-      elidedAttrs.push_back(attrToElide);
-      break;
+    if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
+      if (fastAttr.getValue() == arith::FastMathFlags::none) {
+        elidedAttrs.push_back(attr.getName());
+      }
+    }
+    if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
+      if (denormAttr.getValue() == arith::DenormalMode::ieee) {
+        elidedAttrs.push_back(attr.getName());
+      }
     }
   }
   p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 15a3a1fb50dc9e..fa3b0d894c995b 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -320,7 +320,7 @@ int collectStats(MlirOperation operation) {
   // clang-format off
   // CHECK-LABEL: @stats
   // CHECK: Number of operations: 12
-  // CHECK: Number of attributes: 5
+  // CHECK: Number of attributes: 6
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
   // CHECK: Number of values: 9

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..f56bf0980b13c1 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3189,3 +3189,26 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
     }
   }
 #-}
+
+// -----
+
+// CHECK-LABEL: @test_fold_denorm
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_fold_denorm(%arg0: f32) -> f32 {
+  // CHECK-NOT: arith.subf
+  // CHECK: return %[[ARG0]] : f32
+  %c_denorm = arith.constant 1.4e-45 : f32
+  %sub = arith.subf %arg0, %c_denorm denormal<positive_zero> : f32
+  return %sub : f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_expect_not_to_fold_denorm
+func.func @test_expect_not_to_fold_denorm(%arg0: f32, %arg1 : f32) -> (f32, f32) {
+  // CHECK-COUNT-2: arith.subf
+  %c_denorm = arith.constant 1.4e-45 : f32
+  %sub = arith.subf %arg0, %c_denorm denormal<ieee> : f32
+  %sub_1 = arith.subf %arg1, %c_denorm denormal<preserve_sign> : f32
+  return %sub, %sub_1 : f32, f32
+}

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..5892e2a3d078c7 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1161,3 +1161,75 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// CHECK-LABEL: check_denorm_modes(
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+func.func @check_denorm_modes(%arg0: f32, %arg1: f32, %arg2: f32) {
+  // CHECK: %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+  %c_denorm = arith.constant 1.4e-45 : f32
+  // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %sub_preserve_sign = arith.subf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %sub_positive_zero = arith.subf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.subf %[[ARG2]], %[[CST]] : f32
+  %sub_ieee = arith.subf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %add_preserve_sign = arith.addf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %add_positive_zero = arith.addf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.addf %[[ARG2]], %[[CST]] : f32
+  %add_ieee = arith.addf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %mul_preserve_sign = arith.mulf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %mul_positive_zero = arith.mulf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.mulf %[[ARG2]], %[[CST]] : f32
+  %mul_ieee = arith.mulf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %div_preserve_sign = arith.divf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %div_positive_zero = arith.divf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.divf %[[ARG2]], %[[CST]] : f32
+  %div_ieee = arith.divf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %maximumf_preserve_sign = arith.maximumf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %maximumf_positive_zero = arith.maximumf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.maximumf %[[ARG2]], %[[CST]] : f32
+  %maximumf_ieee = arith.maximumf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %maxnumf_preserve_sign = arith.maxnumf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %maxnumf_positive_zero = arith.maxnumf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.maxnumf %[[ARG2]], %[[CST]] : f32
+  %maxnumf_ieee = arith.maxnumf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %minimumf_preserve_sign = arith.minimumf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %minimumf_positive_zero = arith.minimumf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.minimumf %[[ARG2]], %[[CST]] : f32
+  %minimumf_ieee = arith.minimumf %arg2, %c_denorm denormal<ieee> : f32
+
+  // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal<preserve_sign> : f32
+  %minnumf_preserve_sign = arith.minnumf %arg0, %c_denorm denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal<positive_zero> : f32
+  %minnumf_positive_zero = arith.minnumf %arg1, %c_denorm denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.minnumf %[[ARG2]], %[[CST]] : f32
+  %minnumf_ieee = arith.minnumf %arg2, %c_denorm denormal<ieee> : f32
+
+
+  // CHECK: %{{.+}} = arith.negf %{{.+}} denormal<preserve_sign> : f32
+  %negf_preserve_sign = arith.negf %arg0 denormal<preserve_sign> : f32
+  // CHECK: %{{.+}} = arith.negf %{{.+}} denormal<positive_zero> : f32
+  %negf_positive_zero = arith.negf %arg0 denormal<positive_zero> : f32
+  // CHECK: %{{.+}} = arith.negf %[[ARG0]] : f32
+  %negf_ieee = arith.negf %arg0 denormal<ieee> : f32
+
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..e3b6958cfa8816 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
 // -----
 
 func.func @generic(%arg0: memref<?x?xf32>) {
-  // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
+  // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{denormal = #arith.denormal<ieee>, fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
   linalg.generic  {
     indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
     iterator_types = ["parallel", "parallel"]}


        


More information about the Mlir-commits mailing list