[Mlir-commits] [mlir] 4529797 - Add a generic "convert-to-llvm" pass delegating to an interface

Mehdi Amini llvmlistbot at llvm.org
Mon Aug 7 18:46:27 PDT 2023


Author: Mehdi Amini
Date: 2023-08-07T18:46:08-07:00
New Revision: 4529797a9d7c19105815cc9a3f19571b5fca2d06

URL: https://github.com/llvm/llvm-project/commit/4529797a9d7c19105815cc9a3f19571b5fca2d06
DIFF: https://github.com/llvm/llvm-project/commit/4529797a9d7c19105815cc9a3f19571b5fca2d06.diff

LOG: Add a generic "convert-to-llvm" pass delegating to an interface

The multiple -convert-XXX-to-llvm passes are really nice testing tools for
individual dialects, but the expectation is that a proper conversion should
assemble the conversion patterns using the `populateXXXToLLVMConversionPatterns()`
APIs. However, most users just chain the conversion passes for convenience.

This pass makes it easier to assemble, in a composable way, the patterns
required for conversion to the LLVM dialect, by using an interface.
The pass scans the input and collects all the dialects present, and for
those that implement the `ConvertToLLVMPatternInterface` it uses the interface
to populate the conversion patterns, and possibly the conversion target.

Since these conversions can involve intermediate dialects, or target
dialects other than LLVM (for example AVX or NVVM), this pass can't statically
declare the required `getDependentDialects()` before the pass pipeline
begins. This is worked around by using an extension in the DialectRegistry
that will be invoked for every newly loaded dialect in the context. This
allows looking up the interface ahead of time and using it to query the
dependent dialects.

Differential Revision: https://reviews.llvm.org/D157183

Added: 
    mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
    mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
    mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
    mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
    mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp

Modified: 
    mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/IR/DialectRegistry.h
    mlir/include/mlir/InitAllExtensions.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/IR/Dialect.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
new file mode 100644
index 00000000000000..8841c38deafecb
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -0,0 +1,54 @@
+//===- ToLLVMInterface.h - Conversion to LLVM iface ---*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
+#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class ConversionTarget;
+class LLVMTypeConverter;
+class MLIRContext;
+class Operation;
+class RewritePatternSet;
+
+/// Base class for dialect interfaces providing patterns that convert the
+/// dialect's operations to the LLVM dialect (this is dialect conversion, not
+/// translation to LLVM IR). Dialects that can be converted should provide an
+/// implementation of this interface. The interface may be implemented in a
+/// separate conversion library to avoid the "main" dialect library depending
+/// on it; attach it with the delayed registration mechanism of DialectRegistry.
+class ConvertToLLVMPatternInterface
+    : public DialectInterface::Base<ConvertToLLVMPatternInterface> {
+public:
+  ConvertToLLVMPatternInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Hook for derived dialect interfaces to load the dialects they target.
+  /// The LLVM dialect is implicitly already loaded; override this to load
+  /// intermediate dialects used by the conversion, or target dialects such
+  /// as NVVM. The default implementation loads nothing.
+  virtual void loadDependentDialects(MLIRContext *context) const {}
+
+  /// Hook for derived dialect interfaces to add their conversion patterns to
+  /// `patterns`, and optionally mark dialects legal on `target`.
+  virtual void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, RewritePatternSet &patterns) const = 0;
+};
+
+/// Recursively walk the IR and collect all dialects implementing the interface,
+/// and populate the conversion patterns.
+void populateConversionTargetFromOperation(Operation *op,
+                                           ConversionTarget &target,
+                                           RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H

diff  --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
new file mode 100644
index 00000000000000..f685e6c7c4bf07
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
@@ -0,0 +1,23 @@
+//===- ToLLVMPass.h - Conversion to LLVM pass ---*- C++ -*-===================//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
+#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+/// Create a pass that performs dialect conversion to LLVM for all dialects
+/// implementing `ConvertToLLVMPatternInterface`.
+std::unique_ptr<Pass> createConvertToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H

diff  --git a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
index 383ef61f21d705..ee93d73358ef7f 100644
--- a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
+++ b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
@@ -11,7 +11,7 @@
 #include <memory>
 
 namespace mlir {
-
+class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
@@ -21,6 +21,8 @@ class Pass;
 
 void populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns);
 
+void registerConvertNVVMToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 014e976586af4c..7027a25b20db3c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -24,6 +24,7 @@
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9608d771a5dd52..a2b50febb05146 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -11,6 +11,20 @@
 
 include "mlir/Pass/PassBase.td"
 
+
+//===----------------------------------------------------------------------===//
+// ToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
+  let summary = "Convert to LLVM via dialect interfaces found in the input IR";
+  let description = [{
+    This is a generic pass to convert to LLVM, it uses the
+    `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
+    the injection of conversion patterns.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AffineToStandard
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index 7266a5de708b98..b49bdc91536ad9 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -44,7 +44,8 @@ class DialectExtensionBase {
   virtual ~DialectExtensionBase();
 
   /// Return the dialects that our required by this extension to be loaded
-  /// before applying.
+  /// before applying. If empty then the extension is invoked for every loaded
+  /// dialect independently.
   ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
 
   /// Apply this extension to the given context and the required dialects.
@@ -55,12 +56,11 @@ class DialectExtensionBase {
   virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
 
 protected:
-  /// Initialize the extension with a set of required dialects. Note that there
-  /// should always be at least one affected dialect.
+  /// Initialize the extension with a set of required dialects.
+  /// If the list is empty, the extension is invoked for every loaded dialect
+  /// independently.
   DialectExtensionBase(ArrayRef<StringRef> dialectNames)
-      : dialectNames(dialectNames.begin(), dialectNames.end()) {
-    assert(!dialectNames.empty() && "expected at least one affected dialect");
-  }
+      : dialectNames(dialectNames.begin(), dialectNames.end()) {}
 
 private:
   /// The names of the dialects affected by this extension.

diff  --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index e09ec2ff20d893..45e360e8666ec2 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INITALLEXTENSIONS_H_
 #define MLIR_INITALLEXTENSIONS_H_
 
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 
 #include <cstdlib>
@@ -27,6 +28,7 @@ namespace mlir {
 /// pipelines and transformations you are using.
 inline void registerAllExtensions(DialectRegistry &registry) {
   func::registerAllExtensions(registry);
+  registerConvertNVVMToLLVMInterface(registry);
 }
 
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9fabeae0710383..01c45fa63157eb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
 add_subdirectory(ControlFlowToLLVM)
 add_subdirectory(ControlFlowToSPIRV)
+add_subdirectory(ConvertToLLVM)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
 add_subdirectory(GPUCommon)

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..2260e343be0920
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -0,0 +1,27 @@
+set(LLVM_OPTIONAL_SOURCES
+  ConvertToLLVMPass.cpp
+  ToLLVMInterface.cpp
+)
+
+add_mlir_conversion_library(MLIRConvertToLLVMInterface
+  ToLLVMInterface.cpp
+
+  DEPENDS
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSupport
+)
+
+add_mlir_conversion_library(MLIRConvertToLLVMPass
+  ConvertToLLVMPass.cpp
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRConvertToLLVMInterface
+  MLIRPass
+  MLIRIR
+  MLIRSupport
+  )

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
new file mode 100644
index 00000000000000..9766e847c7f3b5
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -0,0 +1,101 @@
+//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <memory>
+
+#define DEBUG_TYPE "convert-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Dialect extension registered with no required dialects, so the registry
+/// calls `apply()` for every dialect loaded in the context. Each dialect that
+/// implements `ConvertToLLVMPatternInterface` then loads the dialects its
+/// conversion needs. Registered by the pass from `getDependentDialects()`.
+class LoadDependentDialectExtension : public DialectExtensionBase {
+public:
+  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
+
+  void apply(MLIRContext *context,
+             MutableArrayRef<Dialect *> dialects) const final {
+    LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
+    for (Dialect *loaded : dialects) {
+      if (auto *convertIface =
+              dyn_cast<ConvertToLLVMPatternInterface>(loaded)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Convert to LLVM found dialect interface for "
+                   << loaded->getNamespace() << "\n");
+        convertIface->loadDependentDialects(context);
+      }
+    }
+  }
+
+  /// Create an independent copy, as required by `DialectExtensionBase`.
+  std::unique_ptr<DialectExtensionBase> clone() const final {
+    return std::make_unique<LoadDependentDialectExtension>(*this);
+  }
+};
+
+/// Generic conversion to the LLVM dialect: every loaded dialect implementing
+/// `ConvertToLLVMPatternInterface` contributes its patterns and legality rules.
+class ConvertToLLVMPass
+    : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
+  /// Built once in `initialize()`; copies of the pass share them read-only.
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
+  std::shared_ptr<const ConversionTarget> target;
+
+public:
+  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<LLVM::LLVMDialect>();
+    // Lets dialect interfaces load their own dependencies as dialects load.
+    registry.addExtensions<LoadDependentDialectExtension>();
+  }
+
+  ConvertToLLVMPass(const ConvertToLLVMPass &other)
+      : ConvertToLLVMPassBase(other), patterns(other.patterns),
+        target(other.target) {}
+
+  LogicalResult initialize(MLIRContext *context) final {
+    RewritePatternSet tempPatterns(context);
+    auto convTarget = std::make_shared<ConversionTarget>(*context);
+    convTarget->addLegalDialect<LLVM::LLVMDialect>();
+    // Every loaded dialect implementing the interface contributes patterns.
+    for (Dialect *dialect : context->getLoadedDialects()) {
+      auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+      if (!iface)
+        continue;
+      iface->populateConvertToLLVMConversionPatterns(*convTarget, tempPatterns);
+    }
+    patterns =
+        std::make_shared<FrozenRewritePatternSet>(std::move(tempPatterns));
+    target = std::move(convTarget);
+    return success();
+  }
+
+  void runOnOperation() final {
+    if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
new file mode 100644
index 00000000000000..6be3defd8781ee
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -0,0 +1,31 @@
+//===- ToLLVMInterface.cpp - MLIR LLVM Conversion -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+void mlir::populateConversionTargetFromOperation(Operation *root,
+                                                 ConversionTarget &target,
+                                                 RewritePatternSet &patterns) {
+  DenseSet<Dialect *> dialects;
+  root->walk([&](Operation *op) {
+    Dialect *dialect = op->getDialect();
+    // Ops of unloaded dialects have a null dialect (and dyn_cast asserts on
+    // null); skip them and dialects already visited, then query the interface.
+    if (!dialect || !dialects.insert(dialect).second)
+      return;
+    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+    if (!iface)
+      return;
+    iface->populateConvertToLLVMConversionPatterns(target, patterns);
+  });
+}

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 330ea77c401295..6d2726f949d9c9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -190,8 +191,29 @@ struct ConvertNVVMToLLVMPass
   }
 };
 
+/// `ConvertToLLVMPatternInterface` implementation for the NVVM dialect.
+struct NVVMToLLVMDialectInterface : ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  /// Makes sure the NVVM dialect is loaded in `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<NVVMDialect>();
+  }
+
+  /// Adds the NVVM-to-LLVM patterns; `target` is left unmodified.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, RewritePatternSet &patterns) const final {
+    populateNVVMToLLVMConversionPatterns(patterns);
+  }
+};
+
 } // namespace
 
 void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
   patterns.add<PtxLowering>(patterns.getContext());
 }
+
+void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext * /*ctx*/, NVVMDialect *nvvm) {
+    nvvm->addInterfaces<NVVMToLLVMDialectInterface>();
+  });
+}

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0dc269330cbb7c..4c1f92983887ba 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -721,6 +722,7 @@ void NVVMDialect::initialize() {
   // Support unknown operations because not all NVVM operations are
   // registered.
   allowUnknownOperations();
+  declarePromisedInterface<ConvertToLLVMPatternInterface>();
 }
 
 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 501f52b83e026e..c4e01ca5a8ae43 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -209,6 +209,11 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
   // Functor used to try to apply the given extension.
   auto applyExtension = [&](const DialectExtensionBase &extension) {
     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
+    // An empty set is equivalent to always invoke.
+    if (dialectNames.empty()) {
+      extension.apply(ctx, dialect);
+      return;
+    }
 
     // Handle the simple case of a single dialect name. In this case, the
     // required dialect should be the current dialect.
@@ -251,6 +256,11 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
   // Functor used to try to apply the given extension.
   auto applyExtension = [&](const DialectExtensionBase &extension) {
     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
+    if (dialectNames.empty()) {
+      auto loadedDialects = ctx->getLoadedDialects();
+      extension.apply(ctx, loadedDialects);
+      return;
+    }
 
     // Check to see if all of the dialects for this extension are loaded.
     SmallVector<Dialect *> requiredDialects;

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d93072b695243..9ba913b9d3ea2a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -1,4 +1,7 @@
 // RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
+// Same as above, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL : @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {


        


More information about the Mlir-commits mailing list