[Mlir-commits] [mlir] cf487cc - [mlir][llvm] Make the import of LLVM IR intrinsics extensible.

Tobias Gysi llvmlistbot at llvm.org
Mon Jan 2 02:52:24 PST 2023


Author: Tobias Gysi
Date: 2023-01-02T11:35:44+01:00
New Revision: cf487cce6f64fe2a397d43108bfdbad01a8754fb

URL: https://github.com/llvm/llvm-project/commit/cf487cce6f64fe2a397d43108bfdbad01a8754fb
DIFF: https://github.com/llvm/llvm-project/commit/cf487cce6f64fe2a397d43108bfdbad01a8754fb.diff

LOG: [mlir][llvm] Make the import of LLVM IR intrinsics extensible.

The revision introduces the LLVMImportDialectInterface to make the
import of LLVM IR intrinsics extensible. It uses a dialect interface
that enables external projects to provide their own conversion functions
for custom intrinsics. These conversion functions can rely on the
ModuleImport class to perform support tasks such as mapping LLVM
values to MLIR values or converting types between the two worlds.

The implementation largely mirrors the export implementation. One major
difference is the dispatch to the appropriate dialect interface, since
LLVM IR intrinsics have no direct association with an MLIR dialect. The
dialect interfaces thus have to publish the supported intrinsics to
ensure incoming conversion calls are dispatched to the right dialect
interface.

The revision implements the extensible intrinsic import discussed as
part of the "extensible LLVM IR import" RFC:
https://discourse.llvm.org/t/rfc-extensible-llvm-ir-import/67256/6

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140374

Added: 
    mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h
    mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/Dialect/All.h
    mlir/include/mlir/Target/LLVMIR/Import.h
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/include/mlir/Tools/mlir-translate/Translation.h
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Tools/mlir-translate/Translation.cpp
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 087ccbf88fc02..048281f146e06 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -378,7 +378,8 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
     # !if(!gt(numResults, 0), "$res = inst;", "");
 
   string mlirBuilder = [{
-    FailureOr<SmallVector<Value>> mlirOperands = convertValues(llvmOperands);
+    FailureOr<SmallVector<Value>> mlirOperands =
+      moduleImport.convertValues(llvmOperands);
     if (failed(mlirOperands))
       return failure();
     SmallVector<Type> resultTypes =
@@ -386,7 +387,7 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
     auto op = $_builder.create<$_qualCppClassName>(
       $_location, resultTypes, *mlirOperands);
     }] # !if(!gt(requiresFastmath, 0),
-      "setFastmathFlagsAttr(inst, op);", "")
+      "moduleImport.setFastmathFlagsAttr(inst, op);", "")
     # !if(!gt(numResults, 0), "$res = op;", "(void)op;");
 }
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f7f168bf6915e..2e46a7180201b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -62,7 +62,7 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName,
   let arguments = !con(commonArgs, fmfArg);
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
-    setFastmathFlagsAttr(inst, op);
+    moduleImport.setFastmathFlagsAttr(inst, op);
     $res = op;
   }];
 }
@@ -82,7 +82,7 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
   string llvmInstName = instName;
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $operand);
-    setFastmathFlagsAttr(inst, op);
+    moduleImport.setFastmathFlagsAttr(inst, op);
     $res = op;
    }];
 }
@@ -157,7 +157,7 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
     auto *fCmpInst = cast<llvm::FCmpInst>(inst);
     auto op = $_builder.create<$_qualCppClassName>(
       $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
-    setFastmathFlagsAttr(inst, op);
+    moduleImport.setFastmathFlagsAttr(inst, op);
     $res = op;
   }];
   // Set the $predicate index to -1 to indicate there is no matching operand
@@ -227,7 +227,8 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
   // FIXME: Import attributes.
   string mlirBuilder = [{
     auto *allocaInst = cast<llvm::AllocaInst>(inst);
-    Type allocatedType = convertType(allocaInst->getAllocatedType());
+    Type allocatedType =
+      moduleImport.convertType(allocaInst->getAllocatedType());
     unsigned alignment = allocaInst->getAlign().value();
     $res = $_builder.create<LLVM::AllocaOp>(
       $_location, $_resultType, allocatedType, $arraySize, alignment);
@@ -825,7 +826,8 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [Pure, ReturnLike]> {
       builder.CreateRetVoid();
   }];
   string mlirBuilder = [{
-    FailureOr<SmallVector<Value>> mlirOperands = convertValues(llvmOperands);
+    FailureOr<SmallVector<Value>> mlirOperands =
+      moduleImport.convertValues(llvmOperands);
     if (failed(mlirOperands))
       return failure();
     $_builder.create<LLVM::ReturnOp>($_location, *mlirOperands);

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 2bbfd7a45d09d..b9e52975692cd 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -17,6 +17,7 @@
 #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
@@ -40,6 +41,13 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerROCDLDialectTranslation(registry);
   registerX86VectorDialectTranslation(registry);
 }
+
+/// Registers all dialects that can be imported from LLVM IR together with
+/// their LLVMImportDialectInterface implementations.
+static inline void
+registerAllFromLLVMIRTranslations(DialectRegistry &registry) {
+  registerLLVMDialectImport(registry);
+}
 } // namespace mlir
 
 #endif // MLIR_TARGET_LLVMIR_DIALECT_ALL_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h
new file mode 100644
index 0000000000000..e3a0a31be9981
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h
@@ -0,0 +1,31 @@
+//===- LLVMIRToLLVMTranslation.h - LLVM IR to LLVM Dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for LLVM IR to LLVM dialect translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Registers the LLVM dialect and its LLVM IR import dialect interface in the
+/// given registry.
+void registerLLVMDialectImport(DialectRegistry &registry);
+
+/// Registers the LLVM dialect and its LLVM IR import dialect interface with
+/// the given context.
+void registerLLVMDialectImport(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMTRANSLATION_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/Import.h b/mlir/include/mlir/Target/LLVMIR/Import.h
index b7b1268686566..808c0f061d595 100644
--- a/mlir/include/mlir/Target/LLVMIR/Import.h
+++ b/mlir/include/mlir/Target/LLVMIR/Import.h
@@ -30,10 +30,11 @@ class DataLayoutSpecInterface;
 class MLIRContext;
 class ModuleOp;
 
-/// Convert the given LLVM module into MLIR's LLVM dialect.  The LLVM context is
-/// extracted from the registered LLVM IR dialect. In case of error, report it
-/// to the error handler registered with the MLIR context, if any (obtained from
-/// the MLIR module), and return `{}`.
+/// Translates the LLVM module into an MLIR module living in the given context.
+/// The translation supports operations from any dialect that has a registered
+/// implementation of the LLVMImportDialectInterface. It returns nullptr if the
+/// translation fails and reports errors using the error handler registered with
+/// the MLIR context.
 OwningOpRef<ModuleOp>
 translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
                         MLIRContext *context);

diff  --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
new file mode 100644
index 0000000000000..014107a4526bc
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -0,0 +1,124 @@
+//===- LLVMImportInterface.h - Import from LLVM interface -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines dialect interfaces for the LLVM IR import.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H
+#define MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace llvm {
+class IRBuilderBase;
+} // namespace llvm
+
+namespace mlir {
+namespace LLVM {
+class ModuleImport;
+} // namespace LLVM
+
+/// Base class for dialect interfaces used to import LLVM IR. Dialects that can
+/// be imported should provide an implementation of this interface for the
+/// supported intrinsics. The interface may be implemented in a separate library
+/// to avoid the "main" dialect library depending on LLVM IR. The interface can
+/// be attached using the delayed registration mechanism available in
+/// DialectRegistry.
+class LLVMImportDialectInterface
+    : public DialectInterface::Base<LLVMImportDialectInterface> {
+public:
+  LLVMImportDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Hook for derived dialect interfaces to implement the import of
+  /// intrinsics into MLIR. The default implementation returns failure.
+  virtual LogicalResult
+  convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
+                   LLVM::ModuleImport &moduleImport) const {
+    return failure();
+  }
+
+  /// Hook for derived dialect interfaces to publish the supported intrinsics.
+  /// As every LLVM IR intrinsic has a unique integer identifier, the function
+  /// returns the list of supported intrinsic identifiers (empty by default).
+  virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }
+};
+
+/// Interface collection for the import of LLVM IR that dispatches to a concrete
+/// dialect interface implementation. Queries the dialect interfaces to obtain a
+/// list of the supported LLVM IR constructs and then builds a mapping for the
+/// efficient dispatch.
+class LLVMImportInterface
+    : public DialectInterfaceCollection<LLVMImportDialectInterface> {
+public:
+  using Base::Base;
+
+  /// Queries all dialect interfaces to build a map from intrinsic identifiers
+  /// to the dialect interface that supports importing the intrinsic. Each
+  /// interface is queried once. Returns failure if multiple dialect interfaces
+  /// translate the same LLVM IR intrinsic.
+  LogicalResult initializeImport() {
+    for (const LLVMImportDialectInterface &iface : *this) {
+      // Verify the supported intrinsics have not been mapped before.
+      ArrayRef<unsigned> supportedIntrinsics = iface.getSupportedIntrinsics();
+      const auto *it = llvm::find_if(supportedIntrinsics, [&](unsigned id) {
+        return intrinsicToDialect.count(id);
+      });
+      if (it != supportedIntrinsics.end()) {
+        return emitError(
+            UnknownLoc::get(iface.getContext()),
+            llvm::formatv("expected unique conversion for intrinsic ({0}), but "
+                          "got conflicting {1} and {2} conversions",
+                          *it, iface.getDialect()->getNamespace(),
+                          intrinsicToDialect.lookup(*it)->getNamespace()));
+      }
+      // Add a mapping for all supported intrinsic identifiers.
+      for (unsigned id : supportedIntrinsics)
+        intrinsicToDialect[id] = iface.getDialect();
+    }
+
+    return success();
+  }
+
+  /// Converts the LLVM intrinsic to an MLIR operation if a conversion exists.
+  /// Returns failure otherwise.
+  LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
+                                 LLVM::ModuleImport &moduleImport) const {
+    // Lookup the dialect interface for the given intrinsic.
+    Dialect *dialect = intrinsicToDialect.lookup(inst->getIntrinsicID());
+    if (!dialect)
+      return failure();
+
+    // Dispatch the conversion to the dialect interface.
+    const LLVMImportDialectInterface *iface = getInterfaceFor(dialect);
+    assert(iface && "expected to find a dialect interface");
+    return iface->convertIntrinsic(builder, inst, moduleImport);
+  }
+
+  /// Returns true if the given LLVM IR intrinsic is convertible to an MLIR
+  /// operation.
+  bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) const {
+    return intrinsicToDialect.count(id);
+  }
+
+private:
+  DenseMap<unsigned, Dialect *> intrinsicToDialect;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 68f3eee582573..903160f09e714 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Target/LLVMIR/Import.h"
+#include "mlir/Target/LLVMIR/LLVMImportInterface.h"
 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
 
 namespace llvm {
@@ -44,6 +45,18 @@ class ModuleImport {
 public:
   ModuleImport(ModuleOp mlirModule, std::unique_ptr<llvm::Module> llvmModule);
 
+  /// Calls the LLVMImportInterface initialization that queries the registered
+  /// dialect interfaces for the supported LLVM IR intrinsics and builds the
+  /// dispatch table. Returns failure if multiple dialect interfaces translate
+  /// the same LLVM IR intrinsic.
+  LogicalResult initializeImportInterface() { return iface.initializeImport(); }
+
+  /// Converts all functions of the LLVM module to MLIR functions.
+  LogicalResult convertFunctions();
+
+  /// Converts all global variables of the LLVM module to MLIR global variables.
+  LogicalResult convertGlobals();
+
   /// Stores the mapping between an LLVM value and its MLIR counterpart.
   void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; }
 
@@ -95,16 +108,6 @@ class ModuleImport {
     return typeTranslator.translateType(type);
   }
 
-  /// Converts an LLVM intrinsic to an MLIR LLVM dialect operation if an MLIR
-  /// counterpart exists. Otherwise, returns failure.
-  LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst,
-                                 llvm::Intrinsic::ID intrinsicID);
-
-  /// Converts an LLVM instruction to an MLIR LLVM dialect operation if an MLIR
-  /// counterpart exists. Otherwise, returns failure.
-  LogicalResult convertOperation(OpBuilder &odsBuilder,
-                                 llvm::Instruction *inst);
-
   /// Imports `func` into the current module.
   LogicalResult processFunction(llvm::Function *func);
 
@@ -115,11 +118,10 @@ class ModuleImport {
   /// Imports `globalVar` as a GlobalOp, creating it if it doesn't exist.
   GlobalOp processGlobal(llvm::GlobalVariable *globalVar);
 
-  /// Converts all functions of the LLVM module to MLIR functions.
-  LogicalResult convertFunctions();
-
-  /// Converts all global variables of the LLVM module to MLIR global variables.
-  LogicalResult convertGlobals();
+  /// Sets the fastmath flags attribute for the imported operation `op` given
+  /// the original instruction `inst`. Asserts if the operation does not
+  /// implement the fastmath interface.
+  void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
 
 private:
   /// Clears the block and value mapping before processing a new region.
@@ -133,14 +135,17 @@ class ModuleImport {
     constantInsertionOp = nullptr;
   }
 
-  /// Sets the fastmath flags attribute for the imported operation `op` given
-  /// the original instruction `inst`. Asserts if the operation does not
-  /// implement the fastmath interface.
-  void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
   /// Returns personality of `func` as a FlatSymbolRefAttr.
   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
   /// Imports `bb` into `block`, which must be initially empty.
   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
+  /// Converts an LLVM intrinsic to an MLIR LLVM dialect operation if an MLIR
+  /// counterpart exists. Otherwise, returns failure.
+  LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst);
+  /// Converts an LLVM instruction to an MLIR LLVM dialect operation if an MLIR
+  /// counterpart exists. Otherwise, returns failure.
+  LogicalResult convertInstruction(OpBuilder &odsBuilder,
+                                   llvm::Instruction *inst);
   /// Imports `inst` and populates valueMapping[inst] with the result of the
   /// imported operation.
   LogicalResult processInstruction(llvm::Instruction *inst);
@@ -192,6 +197,10 @@ class ModuleImport {
   /// The LLVM module being imported.
   std::unique_ptr<llvm::Module> llvmModule;
 
+  /// A dialect interface collection used for dispatching the import to specific
+  /// dialects.
+  LLVMImportInterface iface;
+
   /// Function-local mapping between original and imported block.
   DenseMap<llvm::BasicBlock *, Block *> blockMapping;
   /// Function-local mapping between original and imported values.

diff  --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h
index 7d1896f6db7a4..a4ecaeba15a44 100644
--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h
+++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h
@@ -51,6 +51,10 @@ using TranslateFunction = std::function<LogicalResult(
     const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
     llvm::raw_ostream &output, MLIRContext *)>;
 
+/// Signature of the function that adds all dialects and dialect extensions used
+/// by a translation to the given DialectRegistry.
+using DialectRegistrationFunction = std::function<void(DialectRegistry &)>;
+
 /// This class contains all of the components necessary for performing a
 /// translation.
 class Translation {
@@ -104,14 +108,20 @@ struct TranslateToMLIRRegistration {
   TranslateToMLIRRegistration(
       llvm::StringRef name, llvm::StringRef description,
       const TranslateSourceMgrToMLIRFunction &function,
+      const DialectRegistrationFunction &dialectRegistration =
+          [](DialectRegistry &) {},
       Optional<llvm::Align> inputAlignment = std::nullopt);
   TranslateToMLIRRegistration(
       llvm::StringRef name, llvm::StringRef description,
       const TranslateRawSourceMgrToMLIRFunction &function,
+      const DialectRegistrationFunction &dialectRegistration =
+          [](DialectRegistry &) {},
       Optional<llvm::Align> inputAlignment = std::nullopt);
   TranslateToMLIRRegistration(
       llvm::StringRef name, llvm::StringRef description,
       const TranslateStringRefToMLIRFunction &function,
+      const DialectRegistrationFunction &dialectRegistration =
+          [](DialectRegistry &) {},
       Optional<llvm::Align> inputAlignment = std::nullopt);
 };
 
@@ -119,14 +129,14 @@ struct TranslateFromMLIRRegistration {
   TranslateFromMLIRRegistration(
       llvm::StringRef name, llvm::StringRef description,
       const TranslateFromMLIRFunction &function,
-      const std::function<void(DialectRegistry &)> &dialectRegistration =
+      const DialectRegistrationFunction &dialectRegistration =
           [](DialectRegistry &) {});
 
   template <typename FuncTy, typename OpTy = detail::first_argument<FuncTy>,
             typename = std::enable_if_t<!std::is_same_v<OpTy, Operation *>>>
   TranslateFromMLIRRegistration(
       llvm::StringRef name, llvm::StringRef description, FuncTy function,
-      const std::function<void(DialectRegistry &)> &dialectRegistration =
+      const DialectRegistrationFunction &dialectRegistration =
           [](DialectRegistry &) {})
       : TranslateFromMLIRRegistration(
             name, description,
@@ -137,7 +147,7 @@ struct TranslateFromMLIRRegistration {
                      << "expected a '" << OpTy::getOperationName()
                      << "' op, got '" << op->getName().getStringRef() << "'";
             },
-            dialectRegistration){}
+            dialectRegistration) {}
 };
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name, llvm::StringRef description,

diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 384f663a4d7bd..97577c036a220 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -52,7 +52,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
   )
 
 add_mlir_translation_library(MLIRTargetLLVMIRImport
-  ConvertFromLLVMIR.cpp
   DebugImporter.cpp
   ModuleImport.cpp
   TypeFromLLVM.cpp
@@ -69,3 +68,10 @@ add_mlir_translation_library(MLIRTargetLLVMIRImport
   MLIRLLVMDialect
   MLIRTranslateLib
   )
+
+add_mlir_translation_library(MLIRFromLLVMIRTranslationRegistration
+  ConvertFromLLVMIR.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMIRToLLVMTranslation
+  )

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index dfcdf1d3df2b3..2e05f01119a05 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -11,7 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVMIR/Dialect/All.h"
 #include "mlir/Target/LLVMIR/Import.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
 #include "llvm/IR/Module.h"
@@ -23,7 +25,7 @@ using namespace mlir;
 namespace mlir {
 void registerFromLLVMIRTranslation() {
   TranslateToMLIRRegistration registration(
-      "import-llvm", "translate mlir to llvmir",
+      "import-llvm", "translate llvmir to mlir",
       [](llvm::SourceMgr &sourceMgr,
          MLIRContext *context) -> OwningOpRef<Operation *> {
         llvm::SMDiagnostic err;
@@ -39,6 +41,14 @@ void registerFromLLVMIRTranslation() {
           return {};
         }
         return translateLLVMIRToModule(std::move(llvmModule), context);
+      },
+      [](DialectRegistry &registry) {
+        // Register the DLTI dialect used to express the data layout
+        // specification of the imported module.
+        registry.insert<DLTIDialect>();
+        // Register all dialects that implement the LLVMImportDialectInterface
+        // including the LLVM dialect.
+        registerAllFromLLVMIRTranslations(registry);
       });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt
index 0caeb903f8ab5..616b9fc0fbbaf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt
@@ -1,3 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+  LLVMIRToLLVMTranslation.cpp
+  LLVMToLLVMIRTranslation.cpp
+  )
+
+add_mlir_translation_library(MLIRLLVMIRToLLVMTranslation
+  LLVMIRToLLVMTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRImport
+  )
+
 add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation
   LLVMToLLVMIRTranslation.cpp
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
new file mode 100644
index 0000000000000..ff6ed967d964e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -0,0 +1,102 @@
+//===- LLVMIRToLLVMTranslation.cpp - Translate LLVM IR to LLVM dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between LLVM IR and the MLIR LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+using namespace mlir::LLVM::detail;
+
+#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
+
+/// Returns true if the LLVM IR intrinsic `id` is convertible to an MLIR LLVM
+/// dialect intrinsic; the lookup set is built once, on the first call.
+static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
+  static const DenseSet<unsigned> convertibleIntrinsics = {
+#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
+  };
+  return convertibleIntrinsics.contains(id);
+}
+
+/// Returns the list of LLVM IR intrinsic identifiers that are convertible to
+/// MLIR LLVM dialect intrinsics; the list is built once, on the first call.
+static ArrayRef<unsigned> getSupportedIntrinsicsImpl() {
+  static const SmallVector<unsigned> convertibleIntrinsics = {
+#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
+  };
+  return convertibleIntrinsics;
+}
+
+/// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a
+/// conversion exists. Returns failure otherwise.
+static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
+                                          llvm::CallInst *inst,
+                                          LLVM::ModuleImport &moduleImport) {
+  llvm::Intrinsic::ID intrinsicID = inst->getIntrinsicID();
+
+  // Check if the intrinsic is convertible and, if so, copy its arguments into
+  // `llvmOperands`, which the included conversion code reads.
+  if (isConvertibleIntrinsic(intrinsicID)) {
+    SmallVector<llvm::Value *> args(inst->args());
+    ArrayRef<llvm::Value *> llvmOperands(args);
+#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
+  }
+
+  return failure();
+}
+
+namespace {
+
+/// Implementation of the dialect interface that imports LLVM IR intrinsics as
+/// LLVM dialect operations.
+class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
+public:
+  using LLVMImportDialectInterface::LLVMImportDialectInterface;
+
+  /// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a
+  /// conversion exists. Returns failure otherwise.
+  LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
+                                 LLVM::ModuleImport &moduleImport) const final {
+    return convertIntrinsicImpl(builder, inst, moduleImport);
+  }
+
+  /// Returns the list of LLVM IR intrinsic identifiers that are convertible to
+  /// MLIR LLVM dialect intrinsics.
+  ArrayRef<unsigned> getSupportedIntrinsics() const final {
+    return getSupportedIntrinsicsImpl();
+  }
+};
+} // namespace
+
+void mlir::registerLLVMDialectImport(DialectRegistry &registry) {
+  registry.insert<LLVM::LLVMDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
+    dialect->addInterfaces<LLVMDialectLLVMIRImportInterface>();
+  });
+}
+
+void mlir::registerLLVMDialectImport(MLIRContext &context) {
+  DialectRegistry registry;
+  registerLLVMDialectImport(registry);
+  context.appendDialectRegistry(registry);
+}

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index fbca85c545a48..77828c99398f7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -37,15 +37,6 @@ using namespace mlir::LLVM::detail;
 
 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
 
-/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
-/// intrinsic, or false if no counterpart exists.
-static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
-  static const DenseSet<unsigned> convertibleIntrinsics = {
-#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
-  };
-  return convertibleIntrinsics.contains(id);
-}
-
 // Utility to print an LLVM value as a string for passing to emitError().
 // FIXME: Diagnostic should be able to natively handle types that have
 // operator << (raw_ostream&) defined.
@@ -58,7 +49,7 @@ static std::string diag(llvm::Value &value) {
 
 /// Creates an attribute containing ABI and preferred alignment numbers parsed
 /// a string. The string may be either "abi:preferred" or just "abi". In the
-/// latter case, the prefrred alignment is considered equal to ABI alignment.
+/// latter case, the preferred alignment is considered equal to ABI alignment.
 static DenseIntElementsAttr parseDataLayoutAlignment(MLIRContext &ctx,
                                                      StringRef spec) {
   auto i32 = IntegerType::get(&ctx, 32);
@@ -320,6 +311,7 @@ ModuleImport::ModuleImport(ModuleOp mlirModule,
                            std::unique_ptr<llvm::Module> llvmModule)
     : builder(mlirModule->getContext()), context(mlirModule->getContext()),
       mlirModule(mlirModule), llvmModule(std::move(llvmModule)),
+      iface(mlirModule->getContext()),
       typeTranslator(*mlirModule->getContext()),
       debugImporter(std::make_unique<DebugImporter>(mlirModule->getContext())) {
   builder.setInsertionPointToStart(mlirModule.getBody());
@@ -807,26 +799,20 @@ ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
 }
 
 LogicalResult ModuleImport::convertIntrinsic(OpBuilder &odsBuilder,
-                                             llvm::CallInst *inst,
-                                             llvm::Intrinsic::ID intrinsicID) {
-  Location loc = translateLoc(inst->getDebugLoc());
-
-  // Check if the intrinsic is convertible to an MLIR dialect counterpart and
-  // copy the arguments to an an LLVM operands array reference for conversion.
-  if (isConvertibleIntrinsic(intrinsicID)) {
-    SmallVector<llvm::Value *> args(inst->args());
-    ArrayRef<llvm::Value *> llvmOperands(args);
-#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
-  }
+                                             llvm::CallInst *inst) {
+  if (succeeded(iface.convertIntrinsic(builder, inst, *this)))
+    return success();
 
+  Location loc = translateLoc(inst->getDebugLoc());
   return emitError(loc) << "unhandled intrinsic " << diag(*inst);
 }
 
-LogicalResult ModuleImport::convertOperation(OpBuilder &odsBuilder,
-                                             llvm::Instruction *inst) {
+LogicalResult ModuleImport::convertInstruction(OpBuilder &odsBuilder,
+                                               llvm::Instruction *inst) {
   // Copy the operands to an LLVM operands array reference for conversion.
   SmallVector<llvm::Value *> operands(inst->operands());
   ArrayRef<llvm::Value *> llvmOperands(operands);
+  ModuleImport &moduleImport = *this;
 
   // Convert all instructions that provide an MLIR builder.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
@@ -1006,11 +992,11 @@ LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
   if (auto *callInst = dyn_cast<llvm::CallInst>(inst)) {
     llvm::Function *callee = callInst->getCalledFunction();
     if (callee && callee->isIntrinsic())
-      return convertIntrinsic(builder, callInst, callInst->getIntrinsicID());
+      return convertIntrinsic(builder, callInst);
   }
 
   // Convert all remaining LLVM instructions to MLIR operations.
-  return convertOperation(builder, inst);
+  return convertInstruction(builder, inst);
 }
 
 FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) {
@@ -1049,7 +1035,8 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
 
   auto functionType =
       convertType(func->getFunctionType()).dyn_cast<LLVMFunctionType>();
-  if (func->isIntrinsic() && isConvertibleIntrinsic(func->getIntrinsicID()))
+  if (func->isIntrinsic() &&
+      iface.isConvertibleIntrinsic(func->getIntrinsicID()))
     return success();
 
   bool dsoLocal = func->hasLocalLinkage();
@@ -1151,8 +1138,17 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
 OwningOpRef<ModuleOp>
 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
                               MLIRContext *context) {
-  context->loadDialect<LLVMDialect>();
-  context->loadDialect<DLTIDialect>();
+  // Preload all registered dialects to allow the import to iterate the
+  // registered LLVMImportDialectInterface implementations and query the
+  // supported LLVM IR constructs before starting the translation. Assumes the
+  // LLVM and DLTI dialects that convert the core LLVM IR constructs have been
+  // registered before.
+  assert(llvm::is_contained(context->getAvailableDialects(),
+                            LLVMDialect::getDialectNamespace()));
+  assert(llvm::is_contained(context->getAvailableDialects(),
+                            DLTIDialect::getDialectNamespace()));
+  context->loadAllAvailableDialects();
+
   OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
       StringAttr::get(context, llvmModule->getSourceFileName()), /*line=*/0,
       /*column=*/0)));
@@ -1166,6 +1162,8 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
   module.get()->setAttr(DLTIDialect::kDataLayoutAttrName, dlSpec);
 
   ModuleImport moduleImport(module.get(), std::move(llvmModule));
+  if (failed(moduleImport.initializeImportInterface()))
+    return {};
   if (failed(moduleImport.convertGlobals()))
     return {};
   if (failed(moduleImport.convertFunctions()))

diff  --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp
index afeaed52e329e..ca8ea1ca04aaa 100644
--- a/mlir/lib/Tools/mlir-translate/Translation.cpp
+++ b/mlir/lib/Tools/mlir-translate/Translation.cpp
@@ -73,10 +73,16 @@ TranslateRegistration::TranslateRegistration(
 // Puts `function` into the to-MLIR translation registry unless there is already
 // a function registered for the same name.
 static void registerTranslateToMLIRFunction(
-    StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
+    StringRef name, StringRef description,
+    const DialectRegistrationFunction &dialectRegistration,
+    Optional<llvm::Align> inputAlignment,
     const TranslateSourceMgrToMLIRFunction &function) {
-  auto wrappedFn = [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
-                              raw_ostream &output, MLIRContext *context) {
+  auto wrappedFn = [function, dialectRegistration](
+                       const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+                       raw_ostream &output, MLIRContext *context) {
+    DialectRegistry registry;
+    dialectRegistration(registry);
+    context->appendDialectRegistry(registry);
     OwningOpRef<Operation *> op = function(sourceMgr, context);
     if (!op || failed(verify(*op)))
       return failure();
@@ -89,15 +95,18 @@ static void registerTranslateToMLIRFunction(
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateSourceMgrToMLIRFunction &function,
+    const DialectRegistrationFunction &dialectRegistration,
     Optional<llvm::Align> inputAlignment) {
-  registerTranslateToMLIRFunction(name, description, inputAlignment, function);
+  registerTranslateToMLIRFunction(name, description, dialectRegistration,
+                                  inputAlignment, function);
 }
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateRawSourceMgrToMLIRFunction &function,
+    const DialectRegistrationFunction &dialectRegistration,
     Optional<llvm::Align> inputAlignment) {
   registerTranslateToMLIRFunction(
-      name, description, inputAlignment,
+      name, description, dialectRegistration, inputAlignment,
       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                  MLIRContext *ctx) { return function(*sourceMgr, ctx); });
 }
@@ -106,9 +115,10 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateStringRefToMLIRFunction &function,
+    const DialectRegistrationFunction &dialectRegistration,
     Optional<llvm::Align> inputAlignment) {
   registerTranslateToMLIRFunction(
-      name, description, inputAlignment,
+      name, description, dialectRegistration, inputAlignment,
       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                  MLIRContext *ctx) {
         const llvm::MemoryBuffer *buffer =
@@ -124,7 +134,7 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateFromMLIRFunction &function,
-    const std::function<void(DialectRegistry &)> &dialectRegistration) {
+    const DialectRegistrationFunction &dialectRegistration) {
   registerTranslation(
       name, description, /*inputAlignment=*/std::nullopt,
       [function,

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 2297c55421342..4e59c8b44422a 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -246,11 +246,11 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
         if (isVariadicOperandName(op, name)) {
           as << formatv(
               "FailureOr<SmallVector<Value>> _llvmir_gen_operand_{0} = "
-              "convertValues(llvmOperands.drop_front({1}));\n",
+              "moduleImport.convertValues(llvmOperands.drop_front({1}));\n",
               name, idx);
         } else {
           as << formatv("FailureOr<Value> _llvmir_gen_operand_{0} = "
-                        "convertValue(llvmOperands[{1}]);\n",
+                        "moduleImport.convertValue(llvmOperands[{1}]);\n",
                         name, idx);
         }
         as << formatv("if (failed(_llvmir_gen_operand_{0}))\n"
@@ -261,15 +261,15 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
     } else if (isResultName(op, name)) {
       if (op.getNumResults() != 1)
         return emitError(record, "expected op to have one result");
-      bs << "mapValue(inst)";
+      bs << "moduleImport.mapValue(inst)";
     } else if (name == "_int_attr") {
-      bs << "matchIntegerAttr";
+      bs << "moduleImport.matchIntegerAttr";
     } else if (name == "_var_attr") {
-      bs << "matchLocalVariableAttr";
+      bs << "moduleImport.matchLocalVariableAttr";
     } else if (name == "_resultType") {
-      bs << "convertType(inst->getType())";
+      bs << "moduleImport.convertType(inst->getType())";
     } else if (name == "_location") {
-      bs << "translateLoc(inst->getDebugLoc())";
+      bs << "moduleImport.translateLoc(inst->getDebugLoc())";
     } else if (name == "_builder") {
       bs << "odsBuilder";
     } else if (name == "_qualCppClassName") {


        


More information about the Mlir-commits mailing list