[Mlir-commits] [mlir] a7262d2 - [mlir][arith] Add overflow flags support to arith ops (#77211)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 9 14:17:40 PST 2024


Author: Ivan Butygin
Date: 2024-01-10T01:17:36+03:00
New Revision: a7262d2d9bee9bdfdbcd03ca27a0128c2e2b1c1a

URL: https://github.com/llvm/llvm-project/commit/a7262d2d9bee9bdfdbcd03ca27a0128c2e2b1c1a
DIFF: https://github.com/llvm/llvm-project/commit/a7262d2d9bee9bdfdbcd03ca27a0128c2e2b1c1a.diff

LOG: [mlir][arith] Add overflow flags support to arith ops (#77211)

Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
This mirrors the existing LLVM dialect syntax:
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
``` 

TableGen canonicalization patterns are updated to always drop the flags;
proper support, with tests, will be added later.

The Arith-to-LLVM lowering is updated as part of this commit because, as
currently written, it would crash once new attributes are added to arith
ops.

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

---------

Co-authored-by: Yi Wu <yi.wu2 at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
    mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
    mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Dialect/Arith/ops.mlir
    mlir/test/python/ir/diagnostic_handler.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index eea16b4da6a690..0296ec969d0b32 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -18,14 +18,24 @@
 
 namespace mlir {
 namespace arith {
-// Map arithmetic fastmath enum values to LLVMIR enum values.
+/// Maps arithmetic fastmath enum values to LLVM enum values.
 LLVM::FastmathFlags
 convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
 
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
+/// attribute.
 LLVM::FastmathFlagsAttr
 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 
+/// Maps arithmetic overflow enum values to LLVM enum values.
+LLVM::IntegerOverflowFlags
+convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
+
+/// Creates an LLVM overflow attribute from a given arithmetic overflow
+/// attribute.
+LLVM::IntegerOverflowFlagsAttr
+convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+
 // Attribute converter that populates a NamedAttrList by removing the fastmath
 // attribute from the source operation attributes, and replacing it with an
 // equivalent LLVM fastmath attribute.
@@ -36,12 +46,12 @@ class AttrConvertFastMathToLLVM {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
     // Get the name of the arith fastmath attribute.
-    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+    StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
     // Remove the source fastmath attribute.
-    auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
+    auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
         convertedAttr.erase(arithFMFAttrName));
     if (arithFMFAttr) {
-      llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
+      StringRef targetAttrName = TargetOp::getFastmathAttrName();
       convertedAttr.set(targetAttrName,
                         convertArithFastMathAttrToLLVM(arithFMFAttr));
     }
@@ -49,6 +59,33 @@ class AttrConvertFastMathToLLVM {
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
 
+private:
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the overflow
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM overflow attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertOverflowToLLVM {
+public:
+  AttrConvertOverflowToLLVM(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith overflow attribute.
+    StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+    // Remove the source overflow attribute.
+    auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
+        convertedAttr.erase(arithAttrName));
+    if (arithAttr) {
+      StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+      convertedAttr.set(targetAttrName,
+                        convertArithOveflowAttrToLLVM(arithAttr));
+    }
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
 private:
   NamedAttrList convertedAttr;
 };

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 1e4061392b22d4..3fb7f948b0a45a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+//===----------------------------------------------------------------------===//
+
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "Integer overflow arith flags",
+    [IOFnone, IOFnsw, IOFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def Arith_IntegerOverflowAttr :
+    EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6d133d69dd0faf..cd0102f91ef152 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
   let results = (outs BoolLikeOfAnyRank:$result);
 }
 
+class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic, traits #
+      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+       DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
+    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+      DefaultValuedAttr<
+        Arith_IntegerOverflowAttr,
+        "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
+    Results<(outs SignlessIntegerLike:$result)> {
+
+  let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+                          attr-dict `:` type($result) }];
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
     Performs N-bit addition on the operands. The operands are interpreted as 
@@ -203,8 +217,12 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
 
     The `addi` operation takes two operands and returns one result, each of
     these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
 
     Example:
 
@@ -212,7 +230,10 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
     // Scalar addition.
     %a = arith.addi %b, %c : i64
 
-    // SIMD vector element-wise addition, e.g. for Intel SSE.
+    // Scalar addition with overflow flags.
+    %a = arith.addi %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise addition.
     %f = arith.addi %g, %h : vector<4xi32>
 
     // Tensor element-wise addition.
@@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
+def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
   let summary = [{
     Integer subtraction operation.
   }];
   let description = [{
-    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned 
-    bitvectors. The result is represented by a bitvector containing the mathematical 
-    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith` 
-    integers use a two's complement representation, this operation is applicable on 
+    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
+    bitvectors. The result is represented by a bitvector containing the mathematical
+    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
+    integers use a two's complement representation, this operation is applicable on
     both signed and unsigned integer operands.
 
     The `subi` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
+
+    Example:
+
+    ```mlir
+    // Scalar subtraction.
+    %a = arith.subi %b, %c : i64
+
+    // Scalar subtraction with overflow flags.
+    %a = arith.subi %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise subtraction.
+    %f = arith.subi %g, %h : vector<4xi32>
+
+    // Tensor element-wise subtraction.
+    %x = arith.subi %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
   let summary = [{
     Integer multiplication operation.
   }];
   let description = [{
-    Performs N-bit multiplication on the operands. The operands are interpreted as 
-    unsigned bitvectors. The result is represented by a bitvector containing the 
-    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth. 
-    Because `arith` integers use a two's complement representation, this operation is 
+    Performs N-bit multiplication on the operands. The operands are interpreted as
+    unsigned bitvectors. The result is represented by a bitvector containing the
+    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
+    Because `arith` integers use a two's complement representation, this operation is
     applicable on both signed and unsigned integer operands.
 
     The `muli` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.    
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %a = arith.muli %b, %c : i64
+
+    // Scalar multiplication with overflow flags.
+    %a = arith.muli %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise multiplication.
+    %f = arith.muli %g, %h : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x = arith.muli %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index acaecf6f409dcf..e248422f84db84 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
   ];
 }
 
+def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getOverflowFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoUnsignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoSignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerOverflowAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "overflowFlags";
+      }]
+      >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES

diff  --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 8c5d76f9f2d72e..3e9aef87b9ef1f 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -10,7 +10,6 @@
 
 using namespace mlir;
 
-// Map arithmetic fastmath enum values to LLVMIR enum values.
 LLVM::FastmathFlags
 mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
   LLVM::FastmathFlags llvmFMF{};
@@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
       {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
       {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
       {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
-  for (auto fmfMap : flags) {
-    if (bitEnumContainsAny(arithFMF, fmfMap.first))
-      llvmFMF = llvmFMF | fmfMap.second;
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFMF, arithFlag))
+      llvmFMF = llvmFMF | llvmFlag;
   }
   return llvmFMF;
 }
 
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
 LLVM::FastmathFlagsAttr
 mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
   arith::FastMathFlags arithFMF = fmfAttr.getValue();
   return LLVM::FastmathFlagsAttr::get(
       fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
 }
+
+LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
+    arith::IntegerOverflowFlags arithFlags) {
+  LLVM::IntegerOverflowFlags llvmFlags{};
+  const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
+      flags[] = {
+          {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
+          {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFlags, arithFlag))
+      llvmFlags = llvmFlags | llvmFlag;
+  }
+  return llvmFlags;
+}
+
+LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
+    arith::IntegerOverflowFlagsAttr flagsAttr) {
+  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
+  return LLVM::IntegerOverflowFlagsAttr::get(
+      flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
+}

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5e4213cc4e874a..cf46e0d3ac46ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -35,7 +35,9 @@ namespace {
 using AddFOpLowering =
     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
                                arith::AttrConvertFastMathToLLVM>;
-using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
+using AddIOpLowering =
+    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
@@ -78,7 +80,9 @@ using MinUIOpLowering =
 using MulFOpLowering =
     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
                                arith::AttrConvertFastMathToLLVM>;
-using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
+using MulIOpLowering =
+    VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                                arith::AttrConvertFastMathToLLVM>;
@@ -102,7 +106,9 @@ using SIToFPOpLowering =
 using SubFOpLowering =
     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
                                arith::AttrConvertFastMathToLLVM>;
-using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
+using SubIOpLowering =
+    VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
 using TruncIOpLowering =

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ef951647ccd146..18ceeb0054045e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,12 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+// TODO: Canonicalizations currently don't take integer overflow flags into
+// account and always reset them to the default (wraparound), which is safe
+// but can inhibit later optimizations. Individual patterns must be reviewed
+// for better handling of overflow flags.
+def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -36,23 +42,26 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
+            (DefOverflow))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -63,24 +72,25 @@ def IsScalarOrSplatNegativeOne :
 def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
-           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
-        (Arith_SubIOp $x, $y),
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp $x, $y, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
 def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
-           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
-           $y),
-        (Arith_SubIOp $y, $x),
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
+           $y, $ovf2),
+        (Arith_SubIOp $y, $x, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
 def MulIMulIConstant :
     Pat<(Arith_MulIOp:$res
-          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -90,7 +100,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -100,49 +110,55 @@ def AddUIExtendedToAddI:
 // subi(addi(x, c0), c1) -> addi(x, c0 - c1)
 def SubIRHSAddConstant :
     Pat<(Arith_SubIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
+            (DefOverflow))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
+            (DefOverflow))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
-    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -152,7 +168,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -179,7 +195,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -403,7 +419,7 @@ def TruncIShrSIToTrunciShrUI :
 def TruncIShrUIMulIToMulSIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
+                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
@@ -414,7 +430,7 @@ def TruncIShrUIMulIToMulSIExtended :
 def TruncIShrUIMulIToMulUIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
+                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa77..2d124ce4980fa4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,6 +61,11 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
+  return IntegerOverflowFlagsAttr::get(builder.getContext(),
+                                       IntegerOverflowFlags::none);
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e16dbb5661058c..8937b24e0d174d 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -575,3 +575,16 @@ func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   %7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   return
 }
+
+// -----
+
+// CHECK-LABEL: @ops_supporting_overflow
+func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 6e10e540d1d178..8ae3273f32c6b0 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1138,3 +1138,14 @@ func.func @select_tensor_encoding(
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
   return %0 : tensor<8xi32, "foo">
 }
+
+// CHECK-LABEL: @intflags_func
+func.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = arith.subi %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}

diff  --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index 2f4300d2c55dfa..d516cda8198976 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -113,7 +113,7 @@ def testDiagnosticNonEmptyNotes():
     def callback(d):
         # CHECK: DIAGNOSTIC:
         # CHECK:   message='arith.addi' op requires one result
-        # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
+        # CHECK:   notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
         print(f"DIAGNOSTIC:")
         print(f"  message={d.message}")
         print(f"  notes={list(map(str, d.notes))}")


        


More information about the Mlir-commits mailing list