[Mlir-commits] [mlir] [mlir][LLVM] Add exact flag (PR #115327)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 7 07:05:02 PST 2024
https://github.com/lfrenot updated https://github.com/llvm/llvm-project/pull/115327
From 85ce594b9fa4156c8a561db57f38f1de60919dcb Mon Sep 17 00:00:00 2001
From: Léon Frenot <leon.frenot at ens-lyon.fr>
Date: Thu, 7 Nov 2024 15:04:32 +0000
Subject: [PATCH] [mlir][LLVM] Add exact flag
---
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 27 +++++++++++++++++++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 26 +++++++++++++++---
.../include/mlir/Target/LLVMIR/ModuleImport.h | 5 ++++
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +++
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 +++++++
mlir/test/Target/LLVMIR/Import/exact.ll | 14 ++++++++++
mlir/test/Target/LLVMIR/exact.mlir | 14 ++++++++++
8 files changed, 103 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/Import/exact.ll
create mode 100644 mlir/test/Target/LLVMIR/exact.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 7e38e0b27fd96b..12c430df208925 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
];
}
+def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
+ let description = [{
+ This interface defines an LLVM operation with an exact flag and
+ provides a uniform API for accessing it.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<[{
+ Get the exact flag for the operation.
+ }], "bool", "getIsExact", (ins), [{}], [{
+ return $_op.getProperties().isExact;
+ }]>,
+ InterfaceMethod<[{
+ Set the exact flag for the operation.
+ }], "void", "setIsExact", (ins "bool":$isExact), [{}], [{
+ $_op.getProperties().isExact = isExact;
+ }]>,
+ StaticInterfaceMethod<[{
+ Get the attribute name of the isExact property.
+ }], "StringRef", "getIsExactName", (ins), [{}], [{
+ return "isExact";
+ }]>,
+ ];
+}
+
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d5def510a904d3..b7ce126dbf54dd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
}
+class LLVM_IntArithmeticOpWithIsExact<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<ExactFlagInterface>], traits)> {
+ let arguments = !con(commonArgs, (ins UnitAttr:$isExact));
+
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setExactFlag(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = [{
+ (`exact` $isExact^)? $lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+ }];
+ string llvmBuilder =
+ "$res = builder.Create" # instName #
+ "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -116,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
[Commutative]>;
-def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
-def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
+def LLVM_UDivOp : LLVM_IntArithmeticOpWithIsExact<"udiv", "UDiv">;
+def LLVM_SDivOp : LLVM_IntArithmeticOpWithIsExact<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
@@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
-def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
-def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
+def LLVM_LShrOp : LLVM_IntArithmeticOpWithIsExact<"lshr", "LShr">;
+def LLVM_AShrOp : LLVM_IntArithmeticOpWithIsExact<"ashr", "AShr">;
// Base class for compare operations. A compare operation takes two operands
// of the same type and returns a boolean result. If the operands are
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index bbb7af58d27393..6c3a500f20e3a9 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -187,6 +187,11 @@ class ModuleImport {
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
+ /// Sets the exact flag attribute for the imported operation `op` given
+ /// the original instruction `inst`. Asserts if the operation does not
+ /// implement the exact flag interface.
+ void setExactFlag(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c9bc9533ca2a6b..6b2d8943bf4885 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
printer.printOptionalAttrDict(
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
+ } else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
+ printer.printOptionalAttrDict(filteredAttrs,
+ /*elidedAttrs=*/{iface.getIsExactName()});
} else {
printer.printOptionalAttrDict(filteredAttrs);
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 1f63519373ecab..ccec2034a298b2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -683,6 +683,14 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
iface.setOverflowFlags(value);
}
+void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
+ auto iface = cast<ExactFlagInterface>(op);
+
+ bool value = inst->isExact();
+
+ iface.setIsExact(value);
+}
+
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b8ce7db795a1d1..9daad2ef5b0b1b 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
+// Integer operations with the exact flag.
+// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32
+ %sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32
+ %udiv_flag = llvm.udiv exact %arg0, %arg0 : i32
+ %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
+ %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/exact.ll b/mlir/test/Target/LLVMIR/Import/exact.ll
new file mode 100644
index 00000000000000..528fee5091d2da
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/exact.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @exactflag_inst
+define void @exactflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64
+ %1 = udiv exact i64 %arg1, %arg2
+ ; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64
+ %2 = sdiv exact i64 %arg1, %arg2
+ ; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64
+ %3 = lshr exact i64 %arg1, %arg2
+ ; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64
+ %4 = ashr exact i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/exact.mlir b/mlir/test/Target/LLVMIR/exact.mlir
new file mode 100644
index 00000000000000..b6c378c2fdcc94
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/exact.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @exactflag_func
+llvm.func @exactflag_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}}
+ %0 = llvm.udiv exact %arg0, %arg1 : i64
+ // CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
+ %1 = llvm.sdiv exact %arg0, %arg1 : i64
+ // CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}}
+ %2 = llvm.lshr exact %arg0, %arg1 : i64
+ // CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}}
+ %3 = llvm.ashr exact %arg0, %arg1 : i64
+ llvm.return
+}
More information about the Mlir-commits
mailing list