[Mlir-commits] [mlir] [mlir][LLVM] Add nsw and nuw flags (PR #74508)
Tom Eccles
llvmlistbot at llvm.org
Tue Dec 5 10:26:10 PST 2023
https://github.com/tblah created https://github.com/llvm/llvm-project/pull/74508
The implementation of these flags is modeled after the existing fastmath flags for floating-point arithmetic.
>From 0c5f09c00af6800d26f71daace89139c29ed36c5 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 23 Nov 2023 17:54:13 +0000
Subject: [PATCH] [mlir][LLVM] Add nsw and nuw flags
The implementation of these flags is modeled after the existing fastmath
flags for floating-point arithmetic.
---
mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 23 ++++++++
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 57 +++++++++++++++++++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 23 ++++++--
.../include/mlir/Target/LLVMIR/ModuleImport.h | 5 ++
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 11 +++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 13 +++++
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 ++++
mlir/test/Target/LLVMIR/Import/nsw_nuw.ll | 14 +++++
mlir/test/Target/LLVMIR/nsw_nuw.mlir | 14 +++++
9 files changed, 165 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
create mode 100644 mlir/test/Target/LLVMIR/nsw_nuw.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+// Bit 0 = nsw (no signed wrap), bit 1 = nuw (no unsigned wrap); none = 0.
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+ "IntegerArithFlags",
+ "LLVM integer arithmetic flags",
+ [IAFnone, IAFnsw, IAFnuw]> {
+ let separator = ", ";
+ let cppNamespace = "::mlir::LLVM";
+ let genSpecializedAttr = 0;
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+ EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
];
}
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+ let description = [{
+ Access to the integer overflow flags (nsw/nuw) of an operation.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the IntegerArithFlagsAttr attribute of the operation",
+ /*returnType=*/ "IntegerArithFlagsAttr",
+ /*methodName=*/ "getArithAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getArithFlagsAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the nuw (no unsigned wrap) flag",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNuw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the nsw (no signed wrap) flag",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNsw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the IntegerArithFlagsAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getIntegerArithAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "arithFlags";
+ }]
+ >
+ ];
+}
+
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
+// Integer arithmetic op with optional nsw/nuw flags, printed as `flags <...>`.
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+ dag iafArg = (ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+ let arguments = !con(commonArgs, iafArg);
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setIntegerFlagsAttr(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+ string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
}
// Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub">;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl"> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
/// attributes of LLVMFuncOp `funcOp`.
void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
+ /// Sets the integer arithmetic flags (nsw/nuw) attribute of the imported
+ /// operation `op` from the nsw/nuw flags of the original instruction `inst`.
+ /// Asserts if `op` does not implement IntegerArithFlagsInterface.
+ void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
return filteredAttrs;
}
+/// Drops "arithFlags" (IntegerArithFlagsInterface::getIntegerArithAttrName)
+/// from `attrs`; ops carrying it print it via their `flags` keyword instead.
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute, 8> filteredAttrs(llvm::make_filter_range(
+ attrs, [](NamedAttribute attr) { return attr.getName() != "arithFlags"; }));
+ return filteredAttrs;
+}
+
static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
NamedAttrList &result) {
return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
DictionaryAttr attrs) {
- printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+ printer.printOptionalAttrDict(
+ processFMFAttr(processIntArithAttr(attrs.getValue())));
}
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
}
}
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+ Operation *op) const {
+ auto iface = cast<IntegerArithFlagsInterface>(op);
+ // nsw/nuw only exist on overflowing binary operators; others get no flags.
+ IntegerArithFlags value = IntegerArithFlags::none;
+ if (auto *obo = dyn_cast<llvm::OverflowingBinaryOperator>(inst)) {
+ value = bitEnumSet(value, IntegerArithFlags::nsw, obo->hasNoSignedWrap());
+ value = bitEnumSet(value, IntegerArithFlags::nuw, obo->hasNoUnsignedWrap());
+ }
+ iface->setAttr(iface.getIntegerArithAttrName(),
+ IntegerArithFlagsAttr::get(op->getContext(), value));
+}
+
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
%typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
+// Integer arithmetic flags; printed in <nsw, nuw> order whatever the input order (see shl).
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+ %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+ %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+ %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+ %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+; Verifies that nsw/nuw on add, sub, mul and shl are imported as `flags <...>`.
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+ %1 = add nsw i64 %arg1, %arg2
+ ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+ %2 = sub nuw i64 %arg1, %arg2
+ ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %3 = mul nsw nuw i64 %arg1, %arg2
+ ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %4 = shl nuw nsw i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// Verifies that llvm.add/sub/mul/shl `flags <...>` are exported as nsw/nuw.
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+ %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+ // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+ %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+ // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+ %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+ // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+ %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+ llvm.return
+}
More information about the Mlir-commits
mailing list