[Mlir-commits] [mlir] e9e1c41 - [mlir][LLVM] Add nsw and nuw flags (#74508)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 7 02:35:04 PST 2023


Author: Tom Eccles
Date: 2023-12-07T10:35:00Z
New Revision: e9e1c411b6db8fb739c2c7af0d41bdd48eeed3e5

URL: https://github.com/llvm/llvm-project/commit/e9e1c411b6db8fb739c2c7af0d41bdd48eeed3e5
DIFF: https://github.com/llvm/llvm-project/commit/e9e1c411b6db8fb739c2c7af0d41bdd48eeed3e5.diff

LOG: [mlir][LLVM] Add nsw and nuw flags (#74508)

The implementation of these flags is modeled after the existing fastmath
flags for floating-point arithmetic.

Added: 
    mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
    mlir/test/Target/LLVMIR/nsw_nuw.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Dialect/LLVMIR/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 3b5984498cf83..a7b269eb41ee2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
   let printBitEnumPrimaryGroups = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags (nsw / nuw)
+//===----------------------------------------------------------------------===//
+
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "LLVM integer overflow flags",
+    [IOFnone, IOFnsw, IOFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::LLVM";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerOverflowFlagsAttr :
+    EnumAttr<LLVM_Dialect, IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // FastmathFlags
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..81589eaf5fd0a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
+  let description = [{
+    Access to the integer overflow flags (nsw/nuw) of an operation.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getOverflowFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap (nuw) flag",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoUnsignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap (nsw) flag",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoSignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerOverflowAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "overflowFlags";
+      }]
+      >
+  ];
+}
+
 def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
   let description = [{
     An interface for operations that can carry branch weights metadata. It

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 92460fa06f530..88d9cd2c71c0c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,26 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
     $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
   }];
 }
+class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
+                                           list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
+  dag iofArg = (
+    ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
+  let arguments = !con(commonArgs, iofArg);
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setIntegerOverflowFlagsAttr(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = [{
+    $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+    custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+  }];
+  string llvmBuilder =
+    "$res = builder.Create" # instName #
+    "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +110,11 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
 }
 
 // Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
+    [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub">;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
+    [Commutative]>;
 def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
 def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +124,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl"> {
   let hasFolder = 1;
 }
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..b49d2f539453e 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,12 @@ class ModuleImport {
   /// attributes of LLVMFuncOp `funcOp`.
   void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
 
+  /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
+  /// operation `op` given the original instruction `inst`. Asserts if the
+  /// operation does not implement the IntegerOverflowFlagsInterface.
+  void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
+                                   Operation *op) const;
+
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the fastmath interface.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c67bba04d6971..53e1088f620d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -69,7 +69,15 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,

 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+  auto filteredAttrs = processFMFAttr(attrs.getValue());
+  // Ops with overflow flags print them through their own assembly format
+  // (`overflow<...>`), so elide the attribute from the generic dictionary.
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+    printer.printOptionalAttrDict(
+        filteredAttrs,
+        /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
+  else
+    printer.printOptionalAttrDict(filteredAttrs);
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 4bdffa572e31a..7c51ee7420f9b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,21 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
   }
 }
 
+void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
+                                               Operation *op) const {
+  auto iface = cast<IntegerOverflowFlagsInterface>(op);
+
+  IntegerOverflowFlags value = {};
+  value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
+  value =
+      bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
+
+  auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
+  // Store the flags under the interface-provided name; it must match the
+  // op's `overflowFlags` argument so the LLVM IR export can read them back.
+  iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
+}
+
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
                                         Operation *op) const {
   auto iface = cast<FastmathFlagsInterface>(op);

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 42589972ef2ba..594c3de91815a 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
   %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
+// Integer overflow flags (always printed in nsw, nuw order)
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow<nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow<nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
+  %add_flag = llvm.add %arg0, %arg0 overflow<nsw> : i32
+  %sub_flag = llvm.sub %arg0, %arg0 overflow<nuw> : i32
+  %mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
+  %shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32

diff  --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..d08098a5e5dfe
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,15 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
+  %1 = add nsw i64 %arg1, %arg2
+  ; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
+  %2 = sub nuw i64 %arg1, %arg2
+  ; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %3 = mul nsw nuw i64 %arg1, %arg2
+  ; Flags print as nsw, nuw regardless of their order in the input.
+  ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %4 = shl nuw nsw i64 %arg1, %arg2
+  ret void
+}

diff  --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..6843c2ef0299c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+  %0 = llvm.add %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+  %1 = llvm.sub %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+  %2 = llvm.mul %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+  %3 = llvm.shl %arg0, %arg1 overflow<nsw, nuw> : i64
+  llvm.return
+}


        


More information about the Mlir-commits mailing list