[Mlir-commits] [mlir] [mlir][LLVM] Add disjoint flag (PR #115855)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 12 03:22:17 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (lfrenot)

<details>
<summary>Changes</summary>

The implementation is mostly based on the existing one for the `exact` flag.

The `disjoint` flag means that, for each bit position, that bit is zero in at least one of the inputs. This allows the `or` to be treated as an `add`, since no carry can occur from any bit. If the `disjoint` keyword is present, the result value of the `or` is a [poison value](https://llvm.org/docs/LangRef.html#poisonvalues) if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.

---
Full diff: https://github.com/llvm/llvm-project/pull/115855.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+27) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+21-1) 
- (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5) 
- (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+5) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+8) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+8) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+4) 
- (added) mlir/test/Target/LLVMIR/Import/disjoint.ll (+8) 
- (added) mlir/test/Target/LLVMIR/disjoint.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 352e2ec91bdbea..2699a0ed14d4b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
   ];
 }
 
+def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
+  let description = [{
+    This interface defines an LLVM operation with a disjoint flag and
+    provides a uniform API for accessing it.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<[{
+      Get the disjoint flag for the operation.
+    }], "bool", "getIsDisjoint", (ins), [{}], [{
+      return $_op.getProperties().isDisjoint;
+    }]>,
+    InterfaceMethod<[{
+      Set the disjoint flag for the operation.
+    }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
+      $_op.getProperties().isDisjoint = isDisjoint;
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the isDisjoint property.
+    }], "StringRef", "getIsDisjointName", (ins), [{}], [{
+      return "isDisjoint";
+    }]>,
+  ];
+}
+
 def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
   let description = [{
     This interface defines an LLVM operation with an nneg flag and
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 34f3e4b33b8295..3a3311d8469dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
     "$res = builder.Create" # instName #
     "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
 }
+class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
+  let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
+
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setDisjointFlag(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = [{
+    (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
+  }];
+  string llvmBuilder =
+    [{auto inst = builder.Create}] # instName #
+    [{($lhs, $rhs, /*Name=*/"");
+    moduleTranslation.setDisjointFlag(op, inst);
+    $res = inst;}];
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
 def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
 def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 30164843f63675..eea0647895b01b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
   /// implement the exact flag interface.
   void setExactFlag(llvm::Instruction *inst, Operation *op) const;
 
+  /// Sets the disjoint flag attribute for the imported operation `op`
+  /// given the original instruction `inst`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the nneg flag attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the nneg flag interface.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0b14a665337d58 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -167,6 +167,11 @@ class ModuleTranslation {
   /// attribute.
   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
 
+  /// Sets the disjoint flag attribute for the exported instruction `inst`
+  /// given the original operation `op`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(Operation *op, llvm::Value *inst);
+
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(Type type);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 71d88d3a62f2b9..5592cc7f5df8f1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
   iface.setIsExact(inst->isExact());
 }
 
+void ModuleImport::setDisjointFlag(llvm::Instruction *inst, Operation *op) const {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  iface.setIsDisjoint(inst_disjoint->isDisjoint());
+}
+
 void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
   auto iface = cast<NonNegFlagInterface>(op);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..bbf567f8cf8d4c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1898,6 +1898,14 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
   inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
 }
 
+void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
+}
+
 llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index aa558bad2299ce..06f7b2d9f586fd 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
   %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
 
+// Integer disjoint flag.
+// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
+  %or_flag = llvm.or disjoint %arg0, %arg0 : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll
new file mode 100644
index 00000000000000..36091c09043525
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @disjointflag_inst
+define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
+  %1 = or disjoint i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir
new file mode 100644
index 00000000000000..1f5a42e608ba40
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/disjoint.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @disjointflag_func
+llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
+  %0 = llvm.or disjoint %arg0, %arg1 : i64
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/115855


More information about the Mlir-commits mailing list