[Mlir-commits] [mlir] [mlir][arith] Add overflow flags support to arith ops (PR #77211)

Ivan Butygin llvmlistbot at llvm.org
Sun Jan 7 14:20:26 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/77211

>From 0c1f04dc1ae0667563171147a89859d1231c5b04 Mon Sep 17 00:00:00 2001
From: Yi Wu <yi.wu2 at arm.com>
Date: Fri, 5 Jan 2024 13:57:51 +0000
Subject: [PATCH 1/3] [mlir][arith] Add overflow flags support to arith ops

Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

TableGen canonicalization patterns are updated to always drop the flags; proper support, with tests, will be added later.

Updated the LLVM IR translation as part of this commit; as currently written, it would otherwise crash when new attributes are added to arith ops.
---
 .../ArithCommon/AttrToLLVMConverter.h         | 35 ++++++++
 .../mlir/Dialect/Arith/IR/ArithBase.td        | 23 ++++++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 22 ++++-
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    | 57 +++++++++++++
 .../ArithCommon/AttrToLLVMConverter.cpp       | 29 ++++++-
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 12 ++-
 .../Dialect/Arith/IR/ArithCanonicalization.td | 80 ++++++++++---------
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  5 ++
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 13 +++
 mlir/test/Dialect/Arith/ops.mlir              | 11 +++
 10 files changed, 239 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index eea16b4da6a690..dbd0726fe16d1a 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -26,6 +26,14 @@ convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
 LLVM::FastmathFlagsAttr
 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 
+// Map arithmetic overflow flags (nsw/nuw) to the equivalent LLVM dialect flags.
+LLVM::IntegerOverflowFlags
+convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
+
+// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+LLVM::IntegerOverflowFlagsAttr
+convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr); // NOTE(review): "Oveflow" typo in both names
+
 // Attribute converter that populates a NamedAttrList by removing the fastmath
 // attribute from the source operation attributes, and replacing it with an
 // equivalent LLVM fastmath attribute.
@@ -49,6 +57,33 @@ class AttrConvertFastMathToLLVM {
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
 
+private:
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the overflow
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertOverflowToLLVM {
+public:
+  AttrConvertOverflowToLLVM(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith fastmath attribute.
+    llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+    // Remove the source fastmath attribute.
+    auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
+        convertedAttr.erase(arithAttrName));
+    if (arithAttr) {
+      llvm::StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+      convertedAttr.set(targetAttrName,
+                        convertArithOveflowAttrToLLVM(arithAttr));
+    } // Otherwise the erased attribute, if any, is simply dropped.
+  }
+  // Source attributes with the arith overflow attr converted to LLVM.
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
 private:
   NamedAttrList convertedAttr;
 };
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 1e4061392b22d4..3fb7f948b0a45a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+//===----------------------------------------------------------------------===//
+
+def IOFnone : I32BitEnumAttrCaseNone<"none">; // No flags set.
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>; // No Signed Wrap.
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>; // No Unsigned Wrap.
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "Integer overflow arith flags",
+    [IOFnone, IOFnsw, IOFnuw]> {
+  let separator = ", "; // e.g. overflow<nsw, nuw>
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0; // Arith_IntegerOverflowAttr below is the attr.
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def Arith_IntegerOverflowAttr :
+    EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6d133d69dd0faf..880718bca9e7ec 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,6 +137,22 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
   let results = (outs BoolLikeOfAnyRank:$result);
 }
 
+class Arith_IntArithmeticOpWithOverflowFlag<string mnemonic, list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic, traits #
+      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+      DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
+    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+      DefaultValuedAttr<
+        Arith_IntegerOverflowAttr, "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
+    Results<(outs SignlessIntegerLike:$result)> {
+  // $overflowFlags defaults to `none`; the `overflow<...>` group is optional.
+  let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+                          attr-dict `:` type($result) }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -192,7 +208,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
     Performs N-bit addition on the operands. The operands are interpreted as 
@@ -278,7 +294,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
+def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
   let summary = [{
     Integer subtraction operation.
   }];
@@ -302,7 +318,7 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]> {
   let summary = [{
     Integer multiplication operation.
   }];
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index acaecf6f409dcf..e248422f84db84 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
   ];
 }
 
+def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
+  let description = [{
+    Provides access to the integer overflow (nsw/nuw) flags of an op.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getOverflowFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap (nuw) flag",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoUnsignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap (nsw) flag",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoSignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerOverflowAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "overflowFlags"; // Must match the ODS argument $overflowFlags.
+      }]
+      >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 8c5d76f9f2d72e..7ba12de122bb4d 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -22,9 +22,9 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
       {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
       {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
       {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
-  for (auto fmfMap : flags) {
-    if (bitEnumContainsAny(arithFMF, fmfMap.first))
-      llvmFMF = llvmFMF | fmfMap.second;
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFMF, arithFlag))
+      llvmFMF = llvmFMF | llvmFlag;
   }
   return llvmFMF;
 }
@@ -36,3 +36,26 @@ mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
   return LLVM::FastmathFlagsAttr::get(
       fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
 }
+
+// Map arithmetic overflow enum values to LLVMIR enum values, flag by flag.
+LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
+    arith::IntegerOverflowFlags arithFlags) {
+  LLVM::IntegerOverflowFlags llvmFlags{}; // Value-initialized: all bits clear.
+  const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
+      flags[] = {
+          {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
+          {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
+  for (auto [arithFlag, llvmFlag] : flags) { // Set each matching LLVM flag.
+    if (bitEnumContainsAny(arithFlags, arithFlag))
+      llvmFlags = llvmFlags | llvmFlag;
+  }
+  return llvmFlags;
+}
+
+// Create an LLVM overflow attribute, in the same context, from an arith one.
+LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
+    arith::IntegerOverflowFlagsAttr flagsAttr) {
+  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
+  return LLVM::IntegerOverflowFlagsAttr::get(
+      flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
+}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5e4213cc4e874a..cf46e0d3ac46ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -35,7 +35,9 @@ namespace {
 using AddFOpLowering =
     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
                                arith::AttrConvertFastMathToLLVM>;
-using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
+using AddIOpLowering =
+    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
@@ -78,7 +80,9 @@ using MinUIOpLowering =
 using MulFOpLowering =
     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
                                arith::AttrConvertFastMathToLLVM>;
-using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
+using MulIOpLowering =
+    VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                                arith::AttrConvertFastMathToLLVM>;
@@ -102,7 +106,9 @@ using SIToFPOpLowering =
 using SubFOpLowering =
     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
                                arith::AttrConvertFastMathToLLVM>;
-using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
+using SubIOpLowering =
+    VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
 using TruncIOpLowering =
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ef951647ccd146..19f0c0aac31713 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,8 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -36,23 +38,23 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -63,24 +65,24 @@ def IsScalarOrSplatNegativeOne :
 def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
-           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
-        (Arith_SubIOp $x, $y),
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp $x, $y, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
 def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
-           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
-           $y),
-        (Arith_SubIOp $y, $x),
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
+           $y, $ovf2),
+        (Arith_SubIOp $y, $x, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
 def MulIMulIConstant :
     Pat<(Arith_MulIOp:$res
-          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -90,7 +92,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -100,49 +102,49 @@ def AddUIExtendedToAddI:
 // subi(addi(x, c0), c1) -> addi(x, c0 - c1)
 def SubIRHSAddConstant :
     Pat<(Arith_SubIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x, (DefOverflow))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
-    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -152,7 +154,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -179,7 +181,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -403,7 +405,7 @@ def TruncIShrSIToTrunciShrUI :
 def TruncIShrUIMulIToMulSIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
+                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
@@ -414,7 +416,7 @@ def TruncIShrUIMulIToMulSIExtended :
 def TruncIShrUIMulIToMulUIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
+                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa77..2d124ce4980fa4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,6 +61,11 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
+  return IntegerOverflowFlagsAttr::get(builder.getContext(),
+                                       IntegerOverflowFlags::none); // No flags.
+} // Called from DRR patterns via the DefOverflow NativeCodeCall.
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e16dbb5661058c..8937b24e0d174d 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -575,3 +575,16 @@ func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   %7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   return
 }
+
+// -----
+
+// CHECK-LABEL: @ops_supporting_overflow
+func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 6e10e540d1d178..8ae3273f32c6b0 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1138,3 +1138,14 @@ func.func @select_tensor_encoding(
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
   return %0 : tensor<8xi32, "foo">
 }
+
+// CHECK-LABEL: @intflags_func
+func.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = arith.subi %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}

>From 82eeda06eaf5bfdf3d5efc314d7faf102619291e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 6 Jan 2024 23:53:14 +0100
Subject: [PATCH 2/3] fix python test

---
 mlir/test/python/ir/diagnostic_handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index 2f4300d2c55dfa..d516cda8198976 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -113,7 +113,7 @@ def testDiagnosticNonEmptyNotes():
     def callback(d):
         # CHECK: DIAGNOSTIC:
         # CHECK:   message='arith.addi' op requires one result
-        # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
+        # CHECK:   notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
         print(f"DIAGNOSTIC:")
         print(f"  message={d.message}")
         print(f"  notes={list(map(str, d.notes))}")

>From 3d09907fe711ece18cf2b6535ec4f0f429ac458a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 7 Jan 2024 23:20:06 +0100
Subject: [PATCH 3/3] Update docs and comments

---
 .../ArithCommon/AttrToLLVMConverter.h         |  6 +-
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 79 +++++++++++++++----
 2 files changed, 66 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index dbd0726fe16d1a..5dff716724a475 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -63,16 +63,16 @@ class AttrConvertFastMathToLLVM {
 
 // Attribute converter that populates a NamedAttrList by removing the overflow
 // attribute from the source operation attributes, and replacing it with an
-// equivalent LLVM fastmath attribute.
+// equivalent LLVM overflow attribute.
 template <typename SourceOp, typename TargetOp>
 class AttrConvertOverflowToLLVM {
 public:
   AttrConvertOverflowToLLVM(SourceOp srcOp) {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
-    // Get the name of the arith fastmath attribute.
+    // Get the name of the arith overflow attribute.
     llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
-    // Remove the source fastmath attribute.
+    // Remove the source overflow attribute.
     auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
         convertedAttr.erase(arithAttrName));
     if (arithAttr) {
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 880718bca9e7ec..0f340fbf748f10 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -219,8 +219,12 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
 
     The `addi` operation takes two operands and returns one result, each of
     these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports nuw/nsw flags, which stand for "No Unsigned Wrap" and
+    "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+    result value is undefined if unsigned and/or signed overflow, respectively,
+    occurs.
 
     Example:
 
@@ -228,6 +232,9 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
     // Scalar addition.
     %a = arith.addi %b, %c : i64
 
+    // Scalar addition with overflow flags.
+    %a = arith.addi %b, %c overflow<nsw, nuw> : i64
+
     // SIMD vector element-wise addition, e.g. for Intel SSE.
     %f = arith.addi %g, %h : vector<4xi32>
 
@@ -299,16 +306,36 @@ def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
     Integer subtraction operation.
   }];
   let description = [{
-    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned 
-    bitvectors. The result is represented by a bitvector containing the mathematical 
-    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith` 
-    integers use a two's complement representation, this operation is applicable on 
+    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
+    bitvectors. The result is represented by a bitvector containing the mathematical
+    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
+    integers use a two's complement representation, this operation is applicable on
     both signed and unsigned integer operands.
 
     The `subi` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports nuw/nsw flags, which stand for "No Unsigned Wrap" and
+    "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+    result value is undefined if unsigned and/or signed overflow, respectively,
+    occurs.
+
+    Example:
+
+    ```mlir
+    // Scalar subtraction.
+    %a = arith.subi %b, %c : i64
+
+    // Scalar subtraction with overflow flags.
+    %a = arith.subi %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise subtraction, e.g. for Intel SSE.
+    %f = arith.subi %g, %h : vector<4xi32>
+
+    // Tensor element-wise subtraction.
+    %x = arith.subi %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -323,16 +350,36 @@ def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]>
     Integer multiplication operation.
   }];
   let description = [{
-    Performs N-bit multiplication on the operands. The operands are interpreted as 
-    unsigned bitvectors. The result is represented by a bitvector containing the 
-    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth. 
-    Because `arith` integers use a two's complement representation, this operation is 
+    Performs N-bit multiplication on the operands. The operands are interpreted as
+    unsigned bitvectors. The result is represented by a bitvector containing the
+    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
+    Because `arith` integers use a two's complement representation, this operation is
     applicable on both signed and unsigned integer operands.
 
     The `muli` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.    
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports nuw/nsw flags, which stand for "No Unsigned Wrap" and
+    "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+    result value is undefined if unsigned and/or signed overflow, respectively,
+    occurs.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %a = arith.muli %b, %c : i64
+
+    // Scalar multiplication with overflow flags.
+    %a = arith.muli %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise multiplication, e.g. for Intel SSE.
+    %f = arith.muli %g, %h : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x = arith.muli %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;



More information about the Mlir-commits mailing list