[Mlir-commits] [mlir] 40afff7 - [mlir][LLVM] Add disjoint flag (#115855)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 15 04:48:05 PST 2024
Author: lfrenot
Date: 2024-11-15T13:48:01+01:00
New Revision: 40afff7bd95090a75bc68a0d26b8017cc0ae65c1
URL: https://github.com/llvm/llvm-project/commit/40afff7bd95090a75bc68a0d26b8017cc0ae65c1
DIFF: https://github.com/llvm/llvm-project/commit/40afff7bd95090a75bc68a0d26b8017cc0ae65c1.diff
LOG: [mlir][LLVM] Add disjoint flag (#115855)
The implementation is mostly based on the one existing for the exact
flag.
The `disjoint` flag means that, for each bit position, the bit is zero in
at least one of the inputs. This allows the `or` to be treated as an `add`,
since no carry can occur from any bit. If the `disjoint` keyword is
present, the result value of the `or` is a [poison
value](https://llvm.org/docs/LangRef.html#poisonvalues) if both inputs
have a one in the same bit position. For vectors, only the element
containing the bit is poison.
Added:
mlir/test/Target/LLVMIR/Import/disjoint.ll
mlir/test/Target/LLVMIR/disjoint.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleImport.h
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 352e2ec91bdbea..5ccddef158d9c2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}
+def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
+ let description = [{
+ This interface defines an LLVM operation with a disjoint flag and
+ provides a uniform API for accessing it.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<[{
+ Get the disjoint flag for the operation.
+ }], "bool", "getIsDisjoint", (ins), [{}], [{
+ return $_op.getProperties().isDisjoint;
+ }]>,
+ InterfaceMethod<[{
+ Set the disjoint flag for the operation.
+ }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
+ $_op.getProperties().isDisjoint = isDisjoint;
+ }]>,
+ StaticInterfaceMethod<[{
+ Get the attribute name of the isDisjoint property.
+ }], "StringRef", "getIsDisjointName", (ins), [{}], [{
+ return "isDisjoint";
+ }]>,
+ ];
+}
+
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an nneg flag and
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 5636ee4d9a1109..25f4f616aecf5b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
+class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
+ let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
+
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setDisjointFlag(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = [{
+ (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
+ }];
+ string llvmBuilder = [{
+ auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
+ moduleTranslation.setDisjointFlag(op, inst);
+ $res = inst;
+ }];
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 30164843f63675..eea0647895b01b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
+ /// Sets the disjoint flag attribute for the imported operation `op`
+ /// given the original instruction `inst`. Asserts if the operation does
+ /// not implement the disjoint flag interface.
+ void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the nneg flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the nneg flag interface.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..1b62437761ed9d 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -167,6 +167,12 @@ class ModuleTranslation {
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
+ /// Sets the disjoint flag attribute for the exported instruction `value`
+ /// given the original operation `op`. Asserts if the operation does
+ /// not implement the disjoint flag interface. Does nothing if `value` is
+ /// not a possibly-disjoint instruction (e.g. it was folded to a constant).
+ void setDisjointFlag(Operation *op, llvm::Value *value);
+
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 71d88d3a62f2b9..0d416a5857facb 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}
+void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
+ Operation *op) const {
+ auto iface = cast<DisjointFlagInterface>(op);
+ auto instDisjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+ iface.setIsDisjoint(instDisjoint->isDisjoint());
+}
+
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<NonNegFlagInterface>(op);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..9e58d2a29199e6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}
+void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
+ auto iface = cast<DisjointFlagInterface>(op);
+ // We do a dyn_cast here in case the value got folded into a constant.
+ if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
+ disjointInst->setIsDisjoint(iface.getIsDisjoint());
+}
+
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index fb355678411e5a..aebfd7492093c1 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
+// Integer disjoint flag.
+// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
+ %or_flag = llvm.or disjoint %arg0, %arg0 : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll
new file mode 100644
index 00000000000000..36091c09043525
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @disjointflag_inst
+define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
+ %1 = or disjoint i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir
new file mode 100644
index 00000000000000..1f5a42e608ba40
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/disjoint.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @disjointflag_func
+llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
+ %0 = llvm.or disjoint %arg0, %arg1 : i64
+ llvm.return
+}
More information about the Mlir-commits
mailing list