[Mlir-commits] [mlir] a2c4b7c - [mlir] Add `convertInstruction` and `getSupportedInstructions` to `LLVMImportInterface` (#86799)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 6 23:46:25 PDT 2024


Author: Fabian Mora
Date: 2024-04-07T08:46:21+02:00
New Revision: a2c4b7c8e2740a83f141dcf06cf50359588190b9

URL: https://github.com/llvm/llvm-project/commit/a2c4b7c8e2740a83f141dcf06cf50359588190b9
DIFF: https://github.com/llvm/llvm-project/commit/a2c4b7c8e2740a83f141dcf06cf50359588190b9.diff

LOG: [mlir] Add `convertInstruction` and `getSupportedInstructions` to `LLVMImportInterface` (#86799)

This patch adds the `convertInstruction` and `getSupportedInstructions` hooks
to `LLVMImportInterface`, allowing any non-LLVM dialect to specify how
to import LLVM IR instructions and to override the default import of LLVM instructions.

Added: 
    mlir/test/Target/LLVMIR/Import/test.ll
    mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp

Modified: 
    mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/tools/mlir-translate/mlir-translate.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 9f8da83ae9c205..86bcd580c1b449 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -52,6 +52,15 @@ class LLVMImportDialectInterface
     return failure();
   }
 
+  /// Hook for derived dialect interfaces to implement the import of
+  /// instructions into MLIR. The default implementation returns failure.
+  virtual LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const {
+    return failure();
+  }
+
   /// Hook for derived dialect interfaces to implement the import of metadata
   /// into MLIR. Attaches the converted metadata kind and node to the provided
   /// operation.
@@ -66,6 +75,14 @@ class LLVMImportDialectInterface
   /// returns the list of supported intrinsic identifiers.
   virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }
 
+  /// Hook for derived dialect interfaces to publish the supported instructions.
+  /// The function returns the list of supported LLVM IR instruction opcodes,
+  /// which are used to match LLVM instructions to the appropriate import
+  /// interface and its `convertInstruction` method. It is an error for multiple
+  /// interfaces to support the same opcode. The default implementation returns
+  /// an empty list.
+  virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
+
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
   /// function returns the list of supported metadata identifiers.
@@ -88,21 +105,40 @@ class LLVMImportInterface
   LogicalResult initializeImport() {
     for (const LLVMImportDialectInterface &iface : *this) {
       // Verify the supported intrinsics have not been mapped before.
-      const auto *it =
+      const auto *intrinsicIt =
           llvm::find_if(iface.getSupportedIntrinsics(), [&](unsigned id) {
             return intrinsicToDialect.count(id);
           });
-      if (it != iface.getSupportedIntrinsics().end()) {
+      if (intrinsicIt != iface.getSupportedIntrinsics().end()) {
+        return emitError(
+            UnknownLoc::get(iface.getContext()),
+            llvm::formatv(
+                "expected unique conversion for intrinsic ({0}), but "
+                "got conflicting {1} and {2} conversions",
+                *intrinsicIt, iface.getDialect()->getNamespace(),
+                intrinsicToDialect.lookup(*intrinsicIt)->getNamespace()));
+      }
+      // Verify the supported instructions have not been mapped before.
+      const auto *instructionIt =
+          llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
+            return instructionToDialect.count(id);
+          });
+      if (instructionIt != iface.getSupportedInstructions().end()) {
         return emitError(
             UnknownLoc::get(iface.getContext()),
-            llvm::formatv("expected unique conversion for intrinsic ({0}), but "
-                          "got conflicting {1} and {2} conversions",
-                          *it, iface.getDialect()->getNamespace(),
-                          intrinsicToDialect.lookup(*it)->getNamespace()));
+            llvm::formatv(
+                "expected unique conversion for instruction ({0}), but "
+                "got conflicting {1} and {2} conversions",
+                *instructionIt, iface.getDialect()->getNamespace(),
+                instructionToDialect.lookup(*instructionIt)
+                    ->getDialect()->getNamespace()));
       }
       // Add a mapping for all supported intrinsic identifiers.
       for (unsigned id : iface.getSupportedIntrinsics())
         intrinsicToDialect[id] = iface.getDialect();
+      // Add a mapping for all supported instruction identifiers.
+      for (unsigned id : iface.getSupportedInstructions())
+        instructionToDialect[id] = &iface;
       // Add a mapping for all supported metadata kinds.
       for (unsigned kind : iface.getSupportedMetadata())
         metadataToDialect[kind].push_back(iface.getDialect());
@@ -132,6 +168,26 @@ class LLVMImportInterface
     return intrinsicToDialect.count(id);
   }
 
+  /// Imports `inst` through the dialect interface registered for its opcode.
+  /// Returns failure if no interface is registered or its conversion fails.
+  LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                                   ArrayRef<llvm::Value *> llvmOperands,
+                                   LLVM::ModuleImport &moduleImport) const {
+    // `lookup` yields a null pointer when no interface claimed the opcode, in
+    // which case there is nothing to dispatch to.
+    const LLVMImportDialectInterface *handler =
+        instructionToDialect.lookup(inst->getOpcode());
+    return handler ? handler->convertInstruction(builder, inst, llvmOperands,
+                                                 moduleImport)
+                   : failure();
+  }
+
+  /// Returns true if a dialect interface is registered to import the LLVM IR
+  /// instruction with the given opcode `id`.
+  bool isConvertibleInstruction(unsigned id) const {
+    return instructionToDialect.count(id) != 0;
+  }
+
   /// Attaches the given LLVM metadata to the imported operation if a conversion
   /// to one or more MLIR dialect attributes exists and succeeds. Returns
   /// success if at least one of the conversions is successful and failure if
@@ -166,6 +222,7 @@ class LLVMImportInterface
 
 private:
   DenseMap<unsigned, Dialect *> intrinsicToDialect;
+  DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
   DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
 };
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 6e70d52fa760b6..af998b99d511f0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -123,12 +123,18 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
 /// access to the private module import methods.
 static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                             llvm::Instruction *inst,
-                                            ModuleImport &moduleImport) {
+                                            ModuleImport &moduleImport,
+                                            LLVMImportInterface &iface) {
   // Copy the operands to an LLVM operands array reference for conversion.
   SmallVector<llvm::Value *> operands(inst->operands());
   ArrayRef<llvm::Value *> llvmOperands(operands);
 
   // Convert all instructions that provide an MLIR builder.
+  if (iface.isConvertibleInstruction(inst->getOpcode()))
+    return iface.convertInstruction(odsBuilder, inst, llvmOperands,
+                                    moduleImport);
+  // TODO: Implement the `convertInstruction` hooks in the
+  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
@@ -1596,7 +1602,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
   }
 
   // Convert all instructions that have an mlirBuilder.
-  if (succeeded(convertInstructionImpl(builder, inst, *this)))
+  if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
     return success();
 
   return emitError(loc) << "unhandled instruction: " << diag(*inst);

diff  --git a/mlir/test/Target/LLVMIR/Import/test.ll b/mlir/test/Target/LLVMIR/Import/test.ll
new file mode 100644
index 00000000000000..a3165d60201047
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/test.ll
@@ -0,0 +1,11 @@
+; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s
+
+; CHECK-LABEL: @custom_load
+; CHECK-SAME:  %[[PTR:[[:alnum:]]+]]
+define double @custom_load(ptr %ptr) {
+  ; CHECK:  %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
+  ; CHECK:  %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
+  %1 = load double, ptr %ptr
+  ; CHECK:   llvm.return %[[TEST]] : f64
+  ret double %1
+}

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index b82b1631eead59..47ddcf6524748c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   TestDialect.cpp
   TestPatterns.cpp
   TestTraits.cpp
+  TestFromLLVMIRTranslation.cpp
   TestToLLVMIRTranslation.cpp
 )
 
@@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect
   MLIRTransforms
 )
 
+add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
+  TestFromLLVMIRTranslation.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRTestDialect
+  MLIRSupport
+  MLIRTargetLLVMIRImport
+  MLIRLLVMIRToLLVMTranslation
+)
+
 add_mlir_translation_library(MLIRTestToLLVMIRTranslation
   TestToLLVMIRTranslation.cpp
 

diff  --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..3673d62bea2c94
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -0,0 +1,111 @@
+//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between LLVM IR and the MLIR Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Target/LLVMIR/Import.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace test;
+
+static ArrayRef<unsigned> getSupportedInstructionsImpl() {
+  static const unsigned opcodes[] = {llvm::Instruction::Load}; // Read-only.
+  return opcodes;
+}
+
+static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
+                                 ArrayRef<llvm::Value *> llvmOperands,
+                                 LLVM::ModuleImport &moduleImport) {
+  FailureOr<Value> pointer = moduleImport.convertValue(llvmOperands[0]);
+  if (failed(pointer))
+    return failure();
+  // Emit the load, then map `inst` to a test op using its result twice.
+  Location loc = moduleImport.translateLoc(inst->getDebugLoc());
+  Type resultType = moduleImport.convertType(inst->getType());
+  Value loaded = builder.create<LLVM::LoadOp>(loc, resultType, *pointer);
+  moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
+      loc, resultType, loaded, loaded);
+  return success();
+}
+
+namespace {
+/// Test dialect import interface that reroutes LLVM IR `load` instructions
+/// through `convertLoad`.
+class TestDialectLLVMImportDialectInterface
+    : public LLVMImportDialectInterface {
+public:
+  using LLVMImportDialectInterface::LLVMImportDialectInterface;
+
+  /// Converts `load` instructions; fails for every other opcode.
+  LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const override {
+    if (inst->getOpcode() == llvm::Instruction::Load)
+      return convertLoad(builder, inst, llvmOperands, moduleImport);
+    return failure();
+  }
+
+  /// Publishes the opcodes handled by `convertInstruction`.
+  ArrayRef<unsigned> getSupportedInstructions() const override {
+    return getSupportedInstructionsImpl();
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestFromLLVMIR() {
+  TranslateToMLIRRegistration registration(
+      "test-import-llvmir", "test dialect from LLVM IR",
+      [](llvm::SourceMgr &sourceMgr,
+         MLIRContext *context) -> OwningOpRef<Operation *> {
+        llvm::SMDiagnostic parseDiag;
+        llvm::LLVMContext llvmContext;
+        std::unique_ptr<llvm::Module> llvmModule =
+            llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
+                          parseDiag, llvmContext);
+        if (!llvmModule) {
+          // Forward the LLVM parser diagnostic as an MLIR error.
+          std::string message;
+          llvm::raw_string_ostream messageStream(message);
+          parseDiag.print(/*ProgName=*/"", messageStream);
+          emitError(UnknownLoc::get(context)) << messageStream.str();
+          return nullptr;
+        }
+        if (llvm::verifyModule(*llvmModule, &llvm::errs()))
+          return nullptr;
+
+        return translateLLVMIRToModule(std::move(llvmModule), context, false);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<DLTIDialect, test::TestDialect>();
+        registerLLVMDialectImport(registry);
+        registry.addExtension(
+            +[](MLIRContext *ctx, test::TestDialect *dialect) {
+              dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
+            });
+      });
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 4f9661c058c2d3..309def888a073c 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestToLLVMIR();
+void registerTestFromLLVMIR();
 #endif
 } // namespace mlir
 
@@ -31,6 +32,7 @@ static void registerTestTranslations() {
   registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
   registerTestToLLVMIR();
+  registerTestFromLLVMIR();
 #endif
 }
 


        


More information about the Mlir-commits mailing list