[Mlir-commits] [mlir] [mlir][arith] Add overflow flags support to arith ops (PR #78376)

Jacques Pienaar llvmlistbot at llvm.org
Tue Jan 16 17:55:57 PST 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/78376

>From c2033ade77ed55cabde9b4568e6e0259166693a7 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 9 Jan 2024 23:17:36 +0100
Subject: [PATCH 1/2] [mlir][arith] Add overflow flags support to arith ops

Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

TableGen canonicalization patterns are updated to always drop the flags;
proper support with tests will be added later.

The LLVM IR translation is updated as part of this commit because, as it is
currently written, it would otherwise crash when new attributes are added
to arith ops.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward #77211, #77700 and #77714 while adding a
test to ensure the Python usage is not broken. More follow-up is needed, but
it is unrelated to the core change here. The changes here are minimal and
only correspond to "textual namespacing" on the ODS side; no C++ or Python
changes were needed.

---------

Co-authored-by: Ivan Butygin <ivan.butygin at gmail.com>, Yi Wu <yi.wu2 at arm.com>
---
 .../ArithCommon/AttrToLLVMConverter.h         |  47 +++++++-
 .../mlir/Dialect/Arith/IR/ArithBase.td        |  23 ++++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 101 ++++++++++++++----
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    |  57 ++++++++++
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     |   2 +-
 .../ArithCommon/AttrToLLVMConverter.cpp       |  29 ++++-
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    |  12 ++-
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  59 +++++++++-
 .../Dialect/Arith/IR/ArithCanonicalization.td |  94 +++++++++-------
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |   5 +
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir |  13 +++
 .../ArithToSPIRV/arith-to-spirv.mlir          |  40 +++++++
 mlir/test/Dialect/Arith/ops.mlir              |  11 ++
 mlir/test/python/dialects/arith_llvm.py       |  28 +++++
 mlir/test/python/ir/diagnostic_handler.py     |   2 +-
 15 files changed, 446 insertions(+), 77 deletions(-)
 create mode 100644 mlir/test/python/dialects/arith_llvm.py

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index eea16b4da6a6900..32d7979c32dfb2c 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -18,14 +18,24 @@
 
 namespace mlir {
 namespace arith {
-// Map arithmetic fastmath enum values to LLVMIR enum values.
+/// Maps arithmetic fastmath enum values to LLVM enum values.
 LLVM::FastmathFlags
 convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
 
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
+/// attribute.
 LLVM::FastmathFlagsAttr
 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 
+/// Maps arithmetic overflow enum values to LLVM enum values.
+LLVM::IntegerOverflowFlags
+convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
+
+/// Creates an LLVM overflow attribute from a given arithmetic overflow
+/// attribute.
+LLVM::IntegerOverflowFlagsAttr
+convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+
 // Attribute converter that populates a NamedAttrList by removing the fastmath
 // attribute from the source operation attributes, and replacing it with an
 // equivalent LLVM fastmath attribute.
@@ -36,12 +46,12 @@ class AttrConvertFastMathToLLVM {
     // Copy the source attributes.
     convertedAttr = NamedAttrList{srcOp->getAttrs()};
     // Get the name of the arith fastmath attribute.
-    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+    StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
     // Remove the source fastmath attribute.
-    auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
+    auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
         convertedAttr.erase(arithFMFAttrName));
     if (arithFMFAttr) {
-      llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
+      StringRef targetAttrName = TargetOp::getFastmathAttrName();
       convertedAttr.set(targetAttrName,
                         convertArithFastMathAttrToLLVM(arithFMFAttr));
     }
@@ -49,6 +59,33 @@ class AttrConvertFastMathToLLVM {
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
 
+private:
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the overflow
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM overflow attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertOverflowToLLVM {
+public:
+  AttrConvertOverflowToLLVM(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith overflow attribute.
+    StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+    // Remove the source overflow attribute.
+    auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
+        convertedAttr.erase(arithAttrName));
+    if (arithAttr) {
+      StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+      convertedAttr.set(targetAttrName,
+                        convertArithOverflowAttrToLLVM(arithAttr));
+    }
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
 private:
   NamedAttrList convertedAttr;
 };
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 1e4061392b22d48..c8a42c43c880b09 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// Arith_IntegerOverflowFlags
+//===----------------------------------------------------------------------===//
+
+def Arith_IOFnone : I32BitEnumAttrCaseNone<"none">;
+def Arith_IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def Arith_IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def Arith_IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "Integer overflow arith flags",
+    [Arith_IOFnone, Arith_IOFnsw, Arith_IOFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def Arith_IntegerOverflowAttr :
+    EnumAttr<Arith_Dialect, Arith_IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6d133d69dd0faff..cd0102f91ef1523 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
   let results = (outs BoolLikeOfAnyRank:$result);
 }
 
+class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic, traits #
+      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+       DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
+    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+      DefaultValuedAttr<
+        Arith_IntegerOverflowAttr,
+        "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
+    Results<(outs SignlessIntegerLike:$result)> {
+
+  let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+                          attr-dict `:` type($result) }];
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
     Performs N-bit addition on the operands. The operands are interpreted as 
@@ -203,8 +217,12 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
 
     The `addi` operation takes two operands and returns one result, each of
     these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
 
     Example:
 
@@ -212,7 +230,10 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
     // Scalar addition.
     %a = arith.addi %b, %c : i64
 
-    // SIMD vector element-wise addition, e.g. for Intel SSE.
+    // Scalar addition with overflow flags.
+    %a = arith.addi %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise addition.
     %f = arith.addi %g, %h : vector<4xi32>
 
     // Tensor element-wise addition.
@@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
+def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
   let summary = [{
     Integer subtraction operation.
   }];
   let description = [{
-    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned 
-    bitvectors. The result is represented by a bitvector containing the mathematical 
-    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith` 
-    integers use a two's complement representation, this operation is applicable on 
+    Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
+    bitvectors. The result is represented by a bitvector containing the mathematical
+    value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
+    integers use a two's complement representation, this operation is applicable on
     both signed and unsigned integer operands.
 
     The `subi` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
+
+    Example:
+
+    ```mlir
+    // Scalar subtraction.
+    %a = arith.subi %b, %c : i64
+
+    // Scalar subtraction with overflow flags.
+    %a = arith.subi %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise subtraction.
+    %f = arith.subi %g, %h : vector<4xi32>
+
+    // Tensor element-wise subtraction.
+    %x = arith.subi %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
   let summary = [{
     Integer multiplication operation.
   }];
   let description = [{
-    Performs N-bit multiplication on the operands. The operands are interpreted as 
-    unsigned bitvectors. The result is represented by a bitvector containing the 
-    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth. 
-    Because `arith` integers use a two's complement representation, this operation is 
+    Performs N-bit multiplication on the operands. The operands are interpreted as
+    unsigned bitvectors. The result is represented by a bitvector containing the
+    mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
+    Because `arith` integers use a two's complement representation, this operation is
     applicable on both signed and unsigned integer operands.
 
     The `muli` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
-    a vector whose element type is integer, or a tensor of integers. It has no 
-    standard attributes.    
+    these is required to be the same type. This type may be an integer scalar type,
+    a vector whose element type is integer, or a tensor of integers.
+
+    This op supports `nuw`/`nsw` overflow flags which stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %a = arith.muli %b, %c : i64
+
+    // Scalar multiplication with overflow flags.
+    %a = arith.muli %b, %c overflow<nsw, nuw> : i64
+
+    // SIMD vector element-wise multiplication.
+    %f = arith.muli %g, %h : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x = arith.muli %y, %z : tensor<4x?xi8>
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index acaecf6f409dcff..73a5d9c32ef2057 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
   ];
 }
 
+def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getOverflowFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoUnsignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoSignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerOverflowAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "overflowFlags";
+      }]
+      >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 81589eaf5fd0a43..3b2a132a881e4ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -92,7 +92,7 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
       }]
       >,
     StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the IntegerOveflowFlagsAttr attribute
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
                          for the operation}],
       /*returnType=*/  "StringRef",
       /*methodName=*/  "getIntegerOverflowAttrName",
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 8c5d76f9f2d72e5..dab064a3a954ec5 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -10,7 +10,6 @@
 
 using namespace mlir;
 
-// Map arithmetic fastmath enum values to LLVMIR enum values.
 LLVM::FastmathFlags
 mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
   LLVM::FastmathFlags llvmFMF{};
@@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
       {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
       {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
       {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
-  for (auto fmfMap : flags) {
-    if (bitEnumContainsAny(arithFMF, fmfMap.first))
-      llvmFMF = llvmFMF | fmfMap.second;
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFMF, arithFlag))
+      llvmFMF = llvmFMF | llvmFlag;
   }
   return llvmFMF;
 }
 
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
 LLVM::FastmathFlagsAttr
 mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
   arith::FastMathFlags arithFMF = fmfAttr.getValue();
   return LLVM::FastmathFlagsAttr::get(
       fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
 }
+
+LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
+    arith::IntegerOverflowFlags arithFlags) {
+  LLVM::IntegerOverflowFlags llvmFlags{};
+  const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
+      flags[] = {
+          {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
+          {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFlags, arithFlag))
+      llvmFlags = llvmFlags | llvmFlag;
+  }
+  return llvmFlags;
+}
+
+LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
+    arith::IntegerOverflowFlagsAttr flagsAttr) {
+  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
+  return LLVM::IntegerOverflowFlagsAttr::get(
+      flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
+}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5e4213cc4e874a0..cf46e0d3ac46ac3 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -35,7 +35,9 @@ namespace {
 using AddFOpLowering =
     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
                                arith::AttrConvertFastMathToLLVM>;
-using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
+using AddIOpLowering =
+    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
@@ -78,7 +80,9 @@ using MinUIOpLowering =
 using MulFOpLowering =
     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
                                arith::AttrConvertFastMathToLLVM>;
-using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
+using MulIOpLowering =
+    VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                                arith::AttrConvertFastMathToLLVM>;
@@ -102,7 +106,9 @@ using SIToFPOpLowering =
 using SubFOpLowering =
     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
                                arith::AttrConvertFastMathToLLVM>;
-using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
+using SubIOpLowering =
+    VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
 using TruncIOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index aba6a21deccb0cf..1abad1e9fa4d85c 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -158,8 +158,61 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
   return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
 }
 
+// TODO: Move to some common place?
+static std::string getDecorationString(spirv::Decoration decor) {
+  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
+}
+
 namespace {
 
+/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
+/// operations. Op can potentially support overflow flags.
+template <typename Op, typename SPIRVOp>
+struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() <= 3);
+    auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
+    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+        !getElementTypeOrSelf(op.getType()).isIndex() &&
+        dstType != op.getType()) {
+      return op.emitError("bitwidth emulation is not implemented yet on "
+                          "unsigned op pattern version");
+    }
+
+    auto overflowFlags = arith::IntegerOverflowFlags::none;
+    if (auto overflowIface =
+            dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
+      if (converter->getTargetEnv().allows(
+              spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
+        overflowFlags = overflowIface.getOverflowAttr().getValue();
+    }
+
+    auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, dstType, adaptor.getOperands());
+
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
+                     rewriter.getUnitAttr());
+
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
+                     rewriter.getUnitAttr());
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -1154,9 +1207,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
-    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
-    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
+    ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ef951647ccd1464..18ceeb0054045e8 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,12 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+// TODO: Canonicalizations currently don't take into account integer overflow
+// flags and always reset them to default (wraparound), which is safe but can
+// inhibit later optimizations. Individual patterns must be reviewed for
+// better handling of overflow flags.
+def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -36,23 +42,26 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
+            (DefOverflow))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -63,24 +72,25 @@ def IsScalarOrSplatNegativeOne :
 def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
-           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
-        (Arith_SubIOp $x, $y),
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp $x, $y, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
 def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
-           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
-           $y),
-        (Arith_SubIOp $y, $x),
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
+           $y, $ovf2),
+        (Arith_SubIOp $y, $x, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
 def MulIMulIConstant :
     Pat<(Arith_MulIOp:$res
-          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -90,7 +100,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -100,49 +110,55 @@ def AddUIExtendedToAddI:
 // subi(addi(x, c0), c1) -> addi(x, c0 - c1)
 def SubIRHSAddConstant :
     Pat<(Arith_SubIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
+            (DefOverflow))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
+            (DefOverflow))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
+            (DefOverflow))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
+            (DefOverflow))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
-    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -152,7 +168,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -179,7 +195,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -403,7 +419,7 @@ def TruncIShrSIToTrunciShrUI :
 def TruncIShrUIMulIToMulSIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
+                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
@@ -414,7 +430,7 @@ def TruncIShrUIMulIToMulSIExtended :
 def TruncIShrUIMulIToMulUIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
-                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
+                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
                               (ConstantLikeMatcher AnyAttr:$c0))),
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa776..2d124ce4980fa46 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,6 +61,11 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
+  return IntegerOverflowFlagsAttr::get(builder.getContext(),
+                                       IntegerOverflowFlags::none);
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e16dbb5661058cc..8937b24e0d174d1 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -575,3 +575,16 @@ func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   %7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   return
 }
+
+// -----
+
+// CHECK-LABEL: @ops_supporting_overflow
+func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..8bf90ed0aec8ee1 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @ops_flags
+func.func @ops_flags(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap} : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} {no_unsigned_wrap} : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
+
+} // end module
+
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
+} {
+
+// No decorations should be generated if the corresponding extensions/capabilities are missing
+// CHECK-LABEL: @ops_flags
+func.func @ops_flags(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
+
+} // end module
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 6e10e540d1d178e..8ae3273f32c6b02 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1138,3 +1138,14 @@ func.func @select_tensor_encoding(
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
   return %0 : tensor<8xi32, "foo">
 }
+
+// CHECK-LABEL: @intflags_func
+func.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = arith.subi %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
diff --git a/mlir/test/python/dialects/arith_llvm.py b/mlir/test/python/dialects/arith_llvm.py
new file mode 100644
index 000000000000000..83e3eb2c98fcdc1
--- /dev/null
+++ b/mlir/test/python/dialects/arith_llvm.py
@@ -0,0 +1,28 @@
+# RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
+
+from mlir.ir import *
+import mlir.dialects.arith as arith
+import mlir.dialects.func as func
+from mlir.dialects import llvm
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    f()
+
+
+# CHECK-LABEL: TEST: testOverflowFlags
+# Mostly a regression test to reproduce, and verify the fix for, an error in the Python bindings.
+@run
+def testOverflowFlags():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            a = arith.ConstantOp(value=42, result=IntegerType.get_signless(32))
+            r = arith.AddIOp(
+                a, a, overflowFlags=arith.IntegerOverflowFlags.nsw
+            )
+            # CHECK: arith.addi {{.*}}, {{.*}} overflow<nsw> : i32
+            print(r)
+
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index 2f4300d2c55dfa3..d516cda8198976e 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -113,7 +113,7 @@ def testDiagnosticNonEmptyNotes():
     def callback(d):
         # CHECK: DIAGNOSTIC:
         # CHECK:   message='arith.addi' op requires one result
-        # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
+        # CHECK:   notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
         print(f"DIAGNOSTIC:")
         print(f"  message={d.message}")
         print(f"  notes={list(map(str, d.notes))}")

>From dd016103540b38f650d09064427f122138c95f23 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 17 Jan 2024 01:55:46 +0000
Subject: [PATCH 2/2] Fix lint issues identified by darker

---
 mlir/test/python/dialects/arith_llvm.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/test/python/dialects/arith_llvm.py b/mlir/test/python/dialects/arith_llvm.py
index 83e3eb2c98fcdc1..e8eaf3959a23e59 100644
--- a/mlir/test/python/dialects/arith_llvm.py
+++ b/mlir/test/python/dialects/arith_llvm.py
@@ -20,9 +20,7 @@ def testOverflowFlags():
         module = Module.create()
         with InsertionPoint(module.body):
             a = arith.ConstantOp(value=42, result=IntegerType.get_signless(32))
-            r = arith.AddIOp(
-                a, a, overflowFlags=arith.IntegerOverflowFlags.nsw
-            )
+            r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
             # CHECK: arith.addi {{.*}}, {{.*}} overflow<nsw> : i32
             print(r)
 



More information about the Mlir-commits mailing list