[Mlir-commits] [mlir] [mlir] Add `convertInstruction` and `getSupportedInstructions` to `LLVMImportInterface` (PR #86799)

Fabian Mora llvmlistbot at llvm.org
Wed Mar 27 07:41:51 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86799

>From 5f71a95de1609a40592c3ca2a881bd2f7bb74711 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 12:28:06 +0000
Subject: [PATCH 1/3] [mlir] Add `convertInstruction` and
 `getSupportedInstructions` to `LLVMImportInterface`

This patch adds the `convertInstruction` and `getSupportedInstructions` hooks
to `LLVMImportInterface`, allowing any non-LLVM dialect to specify how to
import LLVM IR instructions.

This patch is necessary for https://github.com/llvm/llvm-project/pull/73057
---
 .../mlir/Target/LLVMIR/LLVMImportInterface.h  | 53 +++++++++++++++++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  8 ++-
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 9f8da83ae9c205..1bd81fcd9400cb 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -52,6 +52,15 @@ class LLVMImportDialectInterface
     return failure();
   }
 
+  /// Hook for derived dialect interfaces to implement the import of
+  /// instructions into MLIR.
+  virtual LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const {
+    return failure();
+  }
+
   /// Hook for derived dialect interfaces to implement the import of metadata
   /// into MLIR. Attaches the converted metadata kind and node to the provided
   /// operation.
@@ -66,6 +75,11 @@ class LLVMImportDialectInterface
   /// returns the list of supported intrinsic identifiers.
   virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }
 
+  /// Hook for derived dialect interfaces to publish the supported instructions.
+  /// As every LLVM IR instruction has a unique integer identifier, the
+  /// function returns the list of supported instruction identifiers.
+  virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
+
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
   /// function returns the list of supported metadata identifiers.
@@ -100,9 +114,27 @@ class LLVMImportInterface
                           *it, iface.getDialect()->getNamespace(),
                           intrinsicToDialect.lookup(*it)->getNamespace()));
       }
+      const auto *instIt =
+          llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
+            return instructionToDialect.count(id);
+          });
+      if (instIt != iface.getSupportedInstructions().end()) {
+        return emitError(
+            UnknownLoc::get(iface.getContext()),
+            llvm::formatv(
+                "expected unique conversion for instruction ({0}), but "
+                "got conflicting {1} and {2} conversions",
+                *instIt, iface.getDialect()->getNamespace(),
+                instructionToDialect.lookup(*instIt)
+                    ->getDialect()
+                    ->getNamespace()));
+      }
       // Add a mapping for all supported intrinsic identifiers.
       for (unsigned id : iface.getSupportedIntrinsics())
         intrinsicToDialect[id] = iface.getDialect();
+      // Add a mapping for all supported instruction identifiers.
+      for (unsigned id : iface.getSupportedInstructions())
+        instructionToDialect[id] = &iface;
       // Add a mapping for all supported metadata kinds.
       for (unsigned kind : iface.getSupportedMetadata())
         metadataToDialect[kind].push_back(iface.getDialect());
@@ -132,6 +164,26 @@ class LLVMImportInterface
     return intrinsicToDialect.count(id);
   }
 
+  /// Converts the LLVM instruction to an MLIR operation if a conversion exists.
+  /// Returns failure otherwise.
+  LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                                   ArrayRef<llvm::Value *> llvmOperands,
+                                   LLVM::ModuleImport &moduleImport) const {
+    // Lookup the dialect interface for the given instruction.
+    const LLVMImportDialectInterface *iface =
+        instructionToDialect.lookup(inst->getOpcode());
+    if (!iface)
+      return failure();
+
+    return iface->convertInstruction(builder, inst, llvmOperands, moduleImport);
+  }
+
+  /// Returns true if a dialect interface registered a conversion for the LLVM
+  /// IR instruction with the given opcode `id`, and false otherwise.
+  bool isConvertibleInstruction(unsigned id) const {
+    return instructionToDialect.count(id);
+  }
+
   /// Attaches the given LLVM metadata to the imported operation if a conversion
   /// to one or more MLIR dialect attributes exists and succeeds. Returns
   /// success if at least one of the conversions is successful and failure if
@@ -166,6 +218,7 @@ class LLVMImportInterface
 
 private:
   DenseMap<unsigned, Dialect *> intrinsicToDialect;
+  DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
   DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
 };
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 6e70d52fa760b6..af419261c8919c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -123,13 +123,17 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
 /// access to the private module import methods.
 static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                             llvm::Instruction *inst,
-                                            ModuleImport &moduleImport) {
+                                            ModuleImport &moduleImport,
+                                            LLVMImportInterface &importIface) {
   // Copy the operands to an LLVM operands array reference for conversion.
   SmallVector<llvm::Value *> operands(inst->operands());
   ArrayRef<llvm::Value *> llvmOperands(operands);
 
   // Convert all instructions that provide an MLIR builder.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
+  if (importIface.isConvertibleInstruction(inst->getOpcode()))
+    return importIface.convertInstruction(odsBuilder, inst, llvmOperands,
+                                          moduleImport);
   return failure();
 }
 
@@ -1596,7 +1600,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
   }
 
   // Convert all instructions that have an mlirBuilder.
-  if (succeeded(convertInstructionImpl(builder, inst, *this)))
+  if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
     return success();
 
   return emitError(loc) << "unhandled instruction: " << diag(*inst);

>From 5118415a746e46b666567ef172ff1c00018a98e3 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 13:23:24 +0000
Subject: [PATCH 2/3] Add test

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |   2 +-
 mlir/test/Target/LLVMIR/Import/test.ll        |  11 ++
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |  18 +++
 .../Test/TestFromLLVMIRTranslation.cpp        | 113 ++++++++++++++++++
 mlir/tools/mlir-translate/mlir-translate.cpp  |   2 +
 5 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/test.ll
 create mode 100644 mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index af419261c8919c..3320c6cd3ab24d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -130,10 +130,10 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   ArrayRef<llvm::Value *> llvmOperands(operands);
 
   // Convert all instructions that provide an MLIR builder.
-#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   if (importIface.isConvertibleInstruction(inst->getOpcode()))
     return importIface.convertInstruction(odsBuilder, inst, llvmOperands,
                                           moduleImport);
+#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
 
diff --git a/mlir/test/Target/LLVMIR/Import/test.ll b/mlir/test/Target/LLVMIR/Import/test.ll
new file mode 100644
index 00000000000000..6f3dd1acf9586d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/test.ll
@@ -0,0 +1,11 @@
+; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s
+
+; CHECK-LABEL: @custom_load
+; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+define double @custom_load(ptr %ptr) {
+  ; CHECK:  %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
+  ; CHECK:  %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
+  %1 = load double, ptr %ptr
+  ; CHECK:   llvm.return %[[TEST]] : f64
+  ret double %1
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index b82b1631eead59..47ddcf6524748c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   TestDialect.cpp
   TestPatterns.cpp
   TestTraits.cpp
+  TestFromLLVMIRTranslation.cpp
   TestToLLVMIRTranslation.cpp
 )
 
@@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect
   MLIRTransforms
 )
 
+add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
+  TestFromLLVMIRTranslation.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRTestDialect
+  MLIRSupport
+  MLIRTargetLLVMIRImport
+  MLIRLLVMIRToLLVMTranslation
+)
+
 add_mlir_translation_library(MLIRTestToLLVMIRTranslation
   TestToLLVMIRTranslation.cpp
 
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..1ecbb5eb445060
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -0,0 +1,113 @@
+//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between LLVM IR and the MLIR Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Target/LLVMIR/Import.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+inline ArrayRef<unsigned> getSupportedInstructionsImpl() {
+  static unsigned instructions[] = {llvm::Instruction::Load};
+  return instructions;
+}
+
+LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
+                          ArrayRef<llvm::Value *> llvmOperands,
+                          LLVM::ModuleImport &moduleImport) {
+  FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
+  if (failed(addr))
+    return failure();
+  auto *loadInst = cast<llvm::LoadInst>(inst);
+  unsigned alignment = loadInst->getAlign().value();
+  // Create the LoadOp
+  Value loadOp = builder.create<LLVM::LoadOp>(
+      moduleImport.translateLoc(inst->getDebugLoc()),
+      moduleImport.convertType(inst->getType()), *addr);
+  moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
+      loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
+  return success();
+}
+
+class TestDialectLLVMImportDialectInterface
+    : public LLVMImportDialectInterface {
+public:
+  using LLVMImportDialectInterface::LLVMImportDialectInterface;
+
+  LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const override {
+    switch (inst->getOpcode()) {
+    case llvm::Instruction::Load:
+      return convertLoad(builder, inst, llvmOperands, moduleImport);
+    default:
+      break;
+    }
+    return failure();
+  }
+
+  ArrayRef<unsigned> getSupportedInstructions() const override {
+    return getSupportedInstructionsImpl();
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestFromLLVMIR() {
+  TranslateToMLIRRegistration registration(
+      "test-import-llvmir", "test dialect from LLVM IR",
+      [](llvm::SourceMgr &sourceMgr,
+         MLIRContext *context) -> OwningOpRef<Operation *> {
+        llvm::SMDiagnostic err;
+        llvm::LLVMContext llvmContext;
+        std::unique_ptr<llvm::Module> llvmModule =
+            llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
+                          err, llvmContext);
+        if (!llvmModule) {
+          std::string errStr;
+          llvm::raw_string_ostream errStream(errStr);
+          err.print(/*ProgName=*/"", errStream);
+          emitError(UnknownLoc::get(context)) << errStream.str();
+          return {};
+        }
+        if (llvm::verifyModule(*llvmModule, &llvm::errs()))
+          return nullptr;
+
+        return translateLLVMIRToModule(std::move(llvmModule), context, false);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<DLTIDialect>();
+        registry.insert<test::TestDialect>();
+        registerLLVMDialectImport(registry);
+        registry.addExtension(
+            +[](MLIRContext *ctx, test::TestDialect *dialect) {
+              dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
+            });
+      });
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 4f9661c058c2d3..309def888a073c 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestToLLVMIR();
+void registerTestFromLLVMIR();
 #endif
 } // namespace mlir
 
@@ -31,6 +32,7 @@ static void registerTestTranslations() {
   registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
   registerTestToLLVMIR();
+  registerTestFromLLVMIR();
 #endif
 }
 

>From 492d257cf84fbbc9731a0f37216716b2a7875c98 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 14:37:00 +0000
Subject: [PATCH 3/3] address reviewer comments

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp                | 10 ++++++----
 mlir/test/Target/LLVMIR/Import/test.ll                 |  2 +-
 .../lib/Dialect/Test/TestFromLLVMIRTranslation.cpp     | 10 +++++-----
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 3320c6cd3ab24d..af998b99d511f0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -124,15 +124,17 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
 static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                             llvm::Instruction *inst,
                                             ModuleImport &moduleImport,
-                                            LLVMImportInterface &importIface) {
+                                            LLVMImportInterface &iface) {
   // Copy the operands to an LLVM operands array reference for conversion.
   SmallVector<llvm::Value *> operands(inst->operands());
   ArrayRef<llvm::Value *> llvmOperands(operands);
 
   // Convert all instructions that provide an MLIR builder.
-  if (importIface.isConvertibleInstruction(inst->getOpcode()))
-    return importIface.convertInstruction(odsBuilder, inst, llvmOperands,
-                                          moduleImport);
+  if (iface.isConvertibleInstruction(inst->getOpcode()))
+    return iface.convertInstruction(odsBuilder, inst, llvmOperands,
+                                    moduleImport);
+    // TODO: Implement the `convertInstruction` hooks in the
+    // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
diff --git a/mlir/test/Target/LLVMIR/Import/test.ll b/mlir/test/Target/LLVMIR/Import/test.ll
index 6f3dd1acf9586d..a3165d60201047 100644
--- a/mlir/test/Target/LLVMIR/Import/test.ll
+++ b/mlir/test/Target/LLVMIR/Import/test.ll
@@ -1,7 +1,7 @@
 ; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s
 
 ; CHECK-LABEL: @custom_load
-; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[PTR:[[:alnum:]]+]]
 define double @custom_load(ptr %ptr) {
   ; CHECK:  %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
   ; CHECK:  %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
index 1ecbb5eb445060..86197299742b73 100644
--- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -30,15 +30,14 @@
 using namespace mlir;
 using namespace test;
 
-namespace {
-inline ArrayRef<unsigned> getSupportedInstructionsImpl() {
+static ArrayRef<unsigned> getSupportedInstructionsImpl() {
   static unsigned instructions[] = {llvm::Instruction::Load};
   return instructions;
 }
 
-LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
-                          ArrayRef<llvm::Value *> llvmOperands,
-                          LLVM::ModuleImport &moduleImport) {
+static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
+                                 ArrayRef<llvm::Value *> llvmOperands,
+                                 LLVM::ModuleImport &moduleImport) {
   FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
   if (failed(addr))
     return failure();
@@ -53,6 +52,7 @@ LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
   return success();
 }
 
+namespace {
 class TestDialectLLVMImportDialectInterface
     : public LLVMImportDialectInterface {
 public:



More information about the Mlir-commits mailing list