[Mlir-commits] [mlir] [mlir][LLVM] Add disjoint flag (PR #115855)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 12 04:50:57 PST 2024


Léon Frenot <leon.frenot at ens-lyon.fr>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/115855 at github.com>


https://github.com/lfrenot updated https://github.com/llvm/llvm-project/pull/115855

>From ce0b93b362877e3cb3aefad82841d735dd3fcceb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9on=20Frenot?= <leon.frenot at ens-lyon.fr>
Date: Tue, 12 Nov 2024 10:01:27 +0000
Subject: [PATCH 1/2] Add disjoint flag

---
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 27 +++++++++++++++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 22 ++++++++++++++-
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  5 ++++
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  5 ++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  9 +++++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  8 ++++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       |  4 +++
 mlir/test/Target/LLVMIR/Import/disjoint.ll    |  8 ++++++
 mlir/test/Target/LLVMIR/disjoint.mlir         |  8 ++++++
 9 files changed, 95 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/disjoint.ll
 create mode 100644 mlir/test/Target/LLVMIR/disjoint.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 352e2ec91bdbea..2699a0ed14d4b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
   ];
 }
 
+def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
+  let description = [{
+    This interface defines an LLVM operation with an disjoint flag and
+    provides a uniform API for accessing it.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<[{
+      Get the disjoint flag for the operation.
+    }], "bool", "getIsDisjoint", (ins), [{}], [{
+      return $_op.getProperties().isDisjoint;
+    }]>,
+    InterfaceMethod<[{
+      Set the disjoint flag for the operation.
+    }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
+      $_op.getProperties().isDisjoint = isDisjoint;
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the isDisjoint property.
+    }], "StringRef", "getIsDisjointName", (ins), [{}], [{
+      return "isDisjoint";
+    }]>,
+  ];
+}
+
 def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
   let description = [{
     This interface defines an LLVM operation with an nneg flag and
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 34f3e4b33b8295..3a3311d8469dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
     "$res = builder.Create" # instName #
     "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
 }
+class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
+  let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
+
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setDisjointFlag(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = [{
+    (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
+  }];
+  string llvmBuilder =
+    [{auto inst = builder.Create}] # instName #
+    [{($lhs, $rhs, /*Name=*/"");
+    moduleTranslation.setDisjointFlag(op, inst);
+    $res = inst;}];
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
 def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
 def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 30164843f63675..eea0647895b01b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
   /// implement the exact flag interface.
   void setExactFlag(llvm::Instruction *inst, Operation *op) const;
 
+  /// Sets the disjoint flag attribute for the imported operation `op`
+  /// given the original instruction `inst`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the nneg flag attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the nneg flag interface.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0b14a665337d58 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -167,6 +167,11 @@ class ModuleTranslation {
   /// attribute.
   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
 
+  /// Sets the disjoint flag attribute for the exported instruction `inst`
+  /// given the original operation `op`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(Operation *op, llvm::Value *inst);
+
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(Type type);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 71d88d3a62f2b9..06a4aa59f91063 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,15 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
   iface.setIsExact(inst->isExact());
 }
 
+void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
+                                   Operation *op) const {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  iface.setIsDisjoint(inst_disjoint->isDisjoint());
+}
+
 void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
   auto iface = cast<NonNegFlagInterface>(op);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..bbf567f8cf8d4c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1898,6 +1898,14 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
   inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
 }
 
+void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
+}
+
 llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index aa558bad2299ce..06f7b2d9f586fd 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
   %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
 
+// Integer disjoint flag.
+// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
+  %or_flag = llvm.or disjoint %arg0, %arg0 : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll
new file mode 100644
index 00000000000000..36091c09043525
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @disjointflag_inst
+define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
+  %1 = or disjoint i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir
new file mode 100644
index 00000000000000..1f5a42e608ba40
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/disjoint.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @disjointflag_func
+llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
+  %0 = llvm.or disjoint %arg0, %arg1 : i64
+  llvm.return
+}

>From 55e1b05af7611171a80f241ffbe4b56e63687018 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9on=20Frenot?= <leon.frenot at ens-lyon.fr>
Date: Tue, 12 Nov 2024 12:46:47 +0000
Subject: [PATCH 2/2] nit fixes

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td | 2 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td        | 8 ++++----
 mlir/lib/Target/LLVMIR/ModuleImport.cpp            | 1 -
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp       | 1 -
 4 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 2699a0ed14d4b3..5ccddef158d9c2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -116,7 +116,7 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
 
 def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
   let description = [{
-    This interface defines an LLVM operation with an disjoint flag and
+    This interface defines an LLVM operation with a disjoint flag and
     provides a uniform API for accessing it.
   }];
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3a3311d8469dfd..847ff6def34b88 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -107,11 +107,11 @@ class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
   let assemblyFormat = [{
     (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
   }];
-  string llvmBuilder =
-    [{auto inst = builder.Create}] # instName #
-    [{($lhs, $rhs, /*Name=*/"");
+  string llvmBuilder = [{
+    auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
     moduleTranslation.setDisjointFlag(op, inst);
-    $res = inst;}];
+    $res = inst;
+  }];
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 06a4aa59f91063..c31c6ef44cf5af 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -692,7 +692,6 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
 void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
                                    Operation *op) const {
   auto iface = cast<DisjointFlagInterface>(op);
-
   auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
 
   iface.setIsDisjoint(inst_disjoint->isDisjoint());
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index bbf567f8cf8d4c..c9cfa0185926f2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1900,7 +1900,6 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
 
 void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
   auto iface = cast<DisjointFlagInterface>(op);
-
   auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
 
   inst_disjoint->setIsDisjoint(iface.getIsDisjoint());



More information about the Mlir-commits mailing list