[Mlir-commits] [mlir] [mlir][LLVM] Add nsw and nuw flags (PR #74508)

Tom Eccles llvmlistbot at llvm.org
Wed Dec 6 04:50:56 PST 2023


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/74508

>From 0c5f09c00af6800d26f71daace89139c29ed36c5 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 23 Nov 2023 17:54:13 +0000
Subject: [PATCH 1/8] [mlir][LLVM] Add nsw and nuw flags

The implementation of these flags is modelled after the existing fastmath flags
for floating point arithmetic.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 23 ++++++++
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 57 +++++++++++++++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 23 ++++++--
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  5 ++
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 11 +++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 13 +++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 10 ++++
 mlir/test/Target/LLVMIR/Import/nsw_nuw.ll     | 14 +++++
 mlir/test/Target/LLVMIR/nsw_nuw.mlir          | 14 +++++
 9 files changed, 165 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
 create mode 100644 mlir/test/Target/LLVMIR/nsw_nuw.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
   let printBitEnumPrimaryGroups = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+    "IntegerArithFlags",
+    "LLVM integer arithmetic flags",
+    [IAFnone, IAFnsw, IAFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::LLVM";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+    EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // FastmathFlags
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a IntegerArithFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerArithFlagsAttr",
+      /*methodName=*/  "getArithAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getArithFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNuw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNsw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerArithFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerArithAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "arithFlags";
+      }]
+      >
+  ];
+}
+
 def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
   let description = [{
     An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
     $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
   }];
 }
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+  dag iafArg = (
+    ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+  let arguments = !con(commonArgs, iafArg);
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setIntegerFlagsAttr(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
 }
 
 // Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub", []>;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
 def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
 def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl", []> {
   let hasFolder = 1;
 }
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
   /// attributes of LLVMFuncOp `funcOp`.
   void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
 
+  /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+  /// operation `op` given the original instruction `inst`. Asserts if the
+  /// operation does not implement the integer arithmetic flag interface.
+  void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   return filteredAttrs;
 }
 
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute, 8> filteredAttrs(
+      llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+        return attr.getName() != "arithFlags";
+      }));
+  return filteredAttrs;
+}
+
 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
                                     NamedAttrList &result) {
   return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+  printer.printOptionalAttrDict(
+      processFMFAttr(processIntArithAttr(attrs.getValue())));
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
   }
 }
 
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+                                       Operation *op) const {
+  IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+
+  IntegerArithFlags value = {};
+  value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
+  value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+
+  IntegerArithFlagsAttr attr =
+      IntegerArithFlagsAttr::get(op->getContext(), value);
+  iface->setAttr(iface.getIntegerArithAttrName(), attr);
+}
+
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
                                         Operation *op) const {
   auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
   %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
+// Integer arithmetic flags
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+  %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+  %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+  %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+  %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+  %1 = add nsw i64 %arg1, %arg2
+  ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+  %2 = sub nuw i64 %arg1, %arg2
+  ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %3 = mul nsw nuw i64 %arg1, %arg2
+  ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %4 = shl nuw nsw i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+  %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+  // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+  %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+  // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+  %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+  // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+  %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+  llvm.return
+}

>From 6c411c5fe9ebe3910f93529a4796f84006f7be28 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 11:44:41 +0000
Subject: [PATCH 2/8] Change name Arith -> Overflow

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 20 ++++++++--------
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 24 +++++++++----------
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 10 ++++----
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  4 ++--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  6 ++---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 15 ++++++------
 6 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 5cde4980ae17d..ec835e05258d8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -429,25 +429,25 @@ def DISubprogramFlags : I32BitEnumAttr<
 }
 
 //===----------------------------------------------------------------------===//
-// IntegerArithFlags
+// IntegerOverflowFlags
 //===----------------------------------------------------------------------===//
 
-def IAFnone : I32BitEnumAttrCaseNone<"none">;
-def IAFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
-def IAFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
 
-def IntegerArithFlags : I32BitEnumAttr<
-    "IntegerArithFlags",
-    "LLVM integer arithmetic flags",
-    [IAFnone, IAFnsw, IAFnuw]> {
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "LLVM integer overflow flags",
+    [IOFnone, IOFnsw, IOFnuw]> {
   let separator = ", ";
   let cppNamespace = "::mlir::LLVM";
   let genSpecializedAttr = 0;
   let printBitEnumPrimaryGroups = 1;
 }
 
-def LLVM_IntegerArithFlagsAttr :
-    EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+def LLVM_IntegerOverflowFlagsAttr :
+    EnumAttr<LLVM_Dialect, IntegerOverflowFlags, "overflow"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 3d3388ac50aff..775c47054d7a9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,7 +48,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
-def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
   let description = [{
     Access to op integer overflow flags.
   }];
@@ -57,14 +57,14 @@ def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
 
   let methods = [
     InterfaceMethod<
-      /*desc=*/        "Returns a IntegerArithFlagsAttr attribute for the operation",
-      /*returnType=*/  "IntegerArithFlagsAttr",
-      /*methodName=*/  "getArithAttr",
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
         auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getArithFlagsAttr();
+        return op.getOverflowFlagsAttr();
       }]
       >,
     InterfaceMethod<
@@ -75,8 +75,8 @@ def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
         auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
       }]
       >,
     InterfaceMethod<
@@ -87,19 +87,19 @@ def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
         auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
       }]
       >,
     StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the IntegerArithFlagsAttr attribute
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
                          for the operation}],
       /*returnType=*/  "StringRef",
-      /*methodName=*/  "getIntegerArithAttrName",
+      /*methodName=*/  "getIntegerOverflowAttrName",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        return "arithFlags";
+        return "overflowFlags";
       }]
       >
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 4a2ef07f505b4..feea3dccf1043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -58,16 +58,16 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
 class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
                                    list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
-    !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
-  dag iafArg = (
-    ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
-  let arguments = !con(commonArgs, iafArg);
+    !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
+  dag iofArg = (
+    ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
+  let arguments = !con(commonArgs, iofArg);
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
     moduleImport.setIntegerFlagsAttr(inst, op);
     $res = op;
   }];
-  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $overflowFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
   string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index de52476636aed..6cb6ad87e93cf 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,9 +172,9 @@ class ModuleImport {
   /// attributes of LLVMFuncOp `funcOp`.
   void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
 
-  /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+  /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
   /// operation `op` given the original instruction `inst`. Asserts if the
-  /// operation does not implement the integer arithmetic flag interface.
+  /// operation does not implement the integer overflow flag interface.
   void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
 
   /// Sets the fastmath flags attribute for the imported operation `op` given
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3d78970cf6c14..960e397683f3d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,10 +62,10 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   return filteredAttrs;
 }
 
-static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+static auto processIntOverflowAttr(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute, 8> filteredAttrs(
       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
-        return attr.getName() != "arithFlags";
+        return attr.getName() != "overflowFlags";
       }));
   return filteredAttrs;
 }
@@ -78,7 +78,7 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
   printer.printOptionalAttrDict(
-      processFMFAttr(processIntArithAttr(attrs.getValue())));
+      processFMFAttr(processIntOverflowAttr(attrs.getValue())));
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index edd0120dcbb71..ec4ab9f4581d5 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -621,15 +621,16 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
 
 void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
                                        Operation *op) const {
-  IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+  auto iface = cast<IntegerOverflowFlagsInterface>(op);
 
-  IntegerArithFlags value = {};
-  value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
-  value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+  IntegerOverflowFlags value = {};
+  value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
+  value =
+      bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
 
-  IntegerArithFlagsAttr attr =
-      IntegerArithFlagsAttr::get(op->getContext(), value);
-  iface->setAttr(iface.getIntegerArithAttrName(), attr);
+  auto attr =
+      IntegerOverflowFlagsAttr::get(op->getContext(), value);
+  iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
 }
 
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,

>From a3d4c390dbc7b3206215ad8219a57b51eb5d1ee1 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 11:55:08 +0000
Subject: [PATCH 3/8] Rename hasN*w -> hasNo(Un)signedWrap

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td | 4 ++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td        | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 775c47054d7a9..81589eaf5fd0a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -70,7 +70,7 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
     InterfaceMethod<
       /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
       /*returnType=*/  "bool",
-      /*methodName=*/  "hasNuw",
+      /*methodName=*/  "hasNoUnsignedWrap",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
@@ -82,7 +82,7 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
     InterfaceMethod<
       /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
       /*returnType=*/  "bool",
-      /*methodName=*/  "hasNsw",
+      /*methodName=*/  "hasNoSignedWrap",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index feea3dccf1043..bcf8442291fee 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -68,7 +68,7 @@ class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
     $res = op;
   }];
   let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $overflowFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
-  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :

>From 305ffe11f596f4abb0ad17def64c3b6ac1efd854 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 12:00:54 +0000
Subject: [PATCH 4/8] Respect TableGen line-length limits

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index bcf8442291fee..e9aed1b62952a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -67,8 +67,13 @@ class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
     moduleImport.setIntegerFlagsAttr(inst, op);
     $res = op;
   }];
-  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $overflowFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
-  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
+  let assemblyFormat = [{
+    $lhs `,` $rhs (`flags` ` ` $overflowFlags^)?
+    custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+  }];
+  string llvmBuilder =
+    "$res = builder.Create" # instName #
+    "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :

>From 99e99473503e27606b9f80e869bd49107a672deb Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 12:17:40 +0000
Subject: [PATCH 5/8] Remove processIntOverflowAttr

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 960e397683f3d..0741ee0846088 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,14 +62,6 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   return filteredAttrs;
 }
 
-static auto processIntOverflowAttr(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute, 8> filteredAttrs(
-      llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
-        return attr.getName() != "overflowFlags";
-      }));
-  return filteredAttrs;
-}
-
 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
                                     NamedAttrList &result) {
   return parser.parseOptionalAttrDict(result);
@@ -77,8 +69,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(
-      processFMFAttr(processIntOverflowAttr(attrs.getValue())));
+  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()),
+                                /*elidedAttrs=*/{"overflowFlags"});
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and

>From 251b4689cdc459ffbdbadd14e4382f3433744f40 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 12:23:06 +0000
Subject: [PATCH 6/8] Don't hardcode integer overflow attr name

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0741ee0846088..79f177f56af69 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -69,8 +69,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()),
-                                /*elidedAttrs=*/{"overflowFlags"});
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+    printer.printOptionalAttrDict(
+        processFMFAttr(attrs.getValue()),
+        /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
+  else
+    printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and

>From 1684ebfa03ef0ff3ab9507d342783ac69a84ac32 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 12:37:36 +0000
Subject: [PATCH 7/8] Rename attribute designator flags -> overflow

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  2 +-
 mlir/test/Dialect/LLVMIR/roundtrip.mlir     | 16 ++++++++--------
 mlir/test/Target/LLVMIR/Import/nsw_nuw.ll   |  8 ++++----
 mlir/test/Target/LLVMIR/nsw_nuw.mlir        |  8 ++++----
 4 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e9aed1b62952a..52b33320b2f02 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -68,7 +68,7 @@ class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
     $res = op;
   }];
   let assemblyFormat = [{
-    $lhs `,` $rhs (`flags` ` ` $overflowFlags^)?
+    $lhs `,` $rhs (`overflow` $overflowFlags^)?
     custom<LLVMOpAttrs>(attr-dict) `:` type($res)
   }];
   string llvmBuilder =
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index dc0f9f453057d..735184bbfead1 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -35,14 +35,14 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
 // Integer arithmetic flags
-// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
-// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
-// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
-// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
-  %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
-  %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
-  %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
-  %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow <nsw, nuw> : i32
+  %add_flag = llvm.add %arg0, %arg0 overflow <nsw> : i32
+  %sub_flag = llvm.sub %arg0, %arg0 overflow <nuw> : i32
+  %mul_flag = llvm.mul %arg0, %arg0 overflow <nsw, nuw> : i32
+  %shl_flag = llvm.shl %arg0, %arg0 overflow <nuw, nsw> : i32
 
 // Floating point binary operations.
 //
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
index 2ea0425ec0ff7..6a1200e03782d 100644
--- a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -2,13 +2,13 @@
 
 ; CHECK-LABEL: @intflag_inst
 define void @intflag_inst(i64 %arg1, i64 %arg2) {
-  ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+  ; CHECK: llvm.add %{{.*}}, %{{.*}} overflow <nsw> : i64
   %1 = add nsw i64 %arg1, %arg2
-  ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+  ; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow <nuw> : i64
   %2 = sub nuw i64 %arg1, %arg2
-  ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  ; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow <nsw, nuw> : i64
   %3 = mul nsw nuw i64 %arg1, %arg2
-  ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow <nsw, nuw> : i64
   %4 = shl nuw nsw i64 %arg1, %arg2
   ret void
 }
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
index 4a7a39bb570c3..6843c2ef0299c 100644
--- a/mlir/test/Target/LLVMIR/nsw_nuw.mlir
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -3,12 +3,12 @@
 // CHECK-LABEL: define void @intflags_func
 llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
   // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
-  %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+  %0 = llvm.add %arg0, %arg1 overflow <nsw> : i64
   // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
-  %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+  %1 = llvm.sub %arg0, %arg1 overflow <nuw> : i64
   // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
-  %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+  %2 = llvm.mul %arg0, %arg1 overflow <nsw, nuw> : i64
   // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
-  %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+  %3 = llvm.shl %arg0, %arg1 overflow <nsw, nuw> : i64
   llvm.return
 }

>From a14b356dc10aeedf796d35276020373ebc690ad4 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 6 Dec 2023 12:50:26 +0000
Subject: [PATCH 8/8] Fix formatting

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ec4ab9f4581d5..d43fe1b56201f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -628,8 +628,7 @@ void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
   value =
       bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
 
-  auto attr =
-      IntegerOverflowFlagsAttr::get(op->getContext(), value);
+  auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
   iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
 }
 



More information about the Mlir-commits mailing list