[Mlir-commits] [mlir] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces (PR #99566)

Fabian Mora llvmlistbot at llvm.org
Sat Nov 23 14:53:44 PST 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/99566

>From f4c3955aaa5228113d3ad0cdb859a890edbadf48 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 25 Jul 2024 20:53:36 +0000
Subject: [PATCH 1/3] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and
 `ConvertToLLVMOpInterface` interfaces

This patch adds the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available.

The `ConvertToLLVMAttrInterface` interface allows attributes to configure conversion to LLVM, including the conversion target, LLVM type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure conversion to LLVM.

The `ConvertToLLVMOpInterface` interface collects all convert to LLVM attributes stored in an operation.

Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
---
 mlir/include/mlir/Conversion/CMakeLists.txt   |  2 +
 .../Conversion/ConvertToLLVM/CMakeLists.txt   |  7 ++
 .../ConvertToLLVM/ToLLVMInterface.h           | 13 +++
 .../ConvertToLLVM/ToLLVMInterface.td          | 76 +++++++++++++++++
 .../mlir/Conversion/GPUCommon/GPUToLLVM.h     | 25 ++++++
 .../mlir/Conversion/GPUToNVVM/GPUToNVVM.h     | 27 ++++++
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |  4 +
 mlir/include/mlir/InitAllExtensions.h         |  4 +
 .../Conversion/ConvertToLLVM/CMakeLists.txt   |  1 +
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 42 ++++++----
 .../ConvertToLLVM/ToLLVMInterface.cpp         | 17 ++++
 .../GPUCommon/GPUToLLVMConversion.cpp         | 32 +++++++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 83 ++++++++++++++-----
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    | 42 ++++++++++
 14 files changed, 335 insertions(+), 40 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
 create mode 100644 mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
 create mode 100644 mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
 create mode 100644 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt
index d212bf3e395e71..9f76ab659215ea 100644
--- a/mlir/include/mlir/Conversion/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/CMakeLists.txt
@@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
 add_public_tablegen_target(MLIRConversionPassIncGen)
 
 add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
+
+add_subdirectory(ConvertToLLVM)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..54d7a03fc22dff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -0,0 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td)
+mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 00aeed9bf29dc2..1d14ff30eb5516 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
 class ConversionTarget;
@@ -50,6 +51,18 @@ void populateConversionTargetFromOperation(Operation *op,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns);
 
+/// Helper function for populating LLVM conversion patterns. If `op` implements
+/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern
+/// attributes provided by the interface will be used to configure the
+/// conversion target, type converter, and the pattern set.
+void populateOpConvertToLLVMConversionPatterns(Operation *op,
+                                               ConversionTarget &target,
+                                               LLVMTypeConverter &typeConverter,
+                                               RewritePatternSet &patterns);
 } // namespace mlir
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc"
+
 #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
new file mode 100644
index 00000000000000..1331a9802c570f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
@@ -0,0 +1,76 @@
+
+//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines interfaces for configuring the conversion to LLVM, including
+// populating the conversion target, type converter, and patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMAttrInterface :
+    AttrInterface<"ConvertToLLVMAttrInterface"> {
+  let description = [{
+    The `ConvertToLLVMAttrInterface` attribute interface allows using
+    attributes to configure the convert-to-LLVM infrastructure, including:
+     - The conversion target.
+     - The LLVM type converter.
+     - The pattern set.
+
+    This interface permits fine-grained configuration of the `convert-to-llvm`
+    process. For example, attributes with target information like
+    `#nvvm.target` or `#rocdl.target` can leverage this interface for populating
+    patterns specific to a particular target.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the dialect conversion target, type converter, and pattern set.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"populateConvertToLLVMConversionPatterns",
+      /*args=*/(ins "::mlir::ConversionTarget&":$target,
+                    "::mlir::LLVMTypeConverter&":$typeConverter,
+                    "::mlir::RewritePatternSet&":$patternSet)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Op interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> {
+  let description = [{
+    Interface for collecting all convert to LLVM attributes stored in an
+    operation. See `ConvertToLLVMAttrInterface` for more information on these
+    attributes.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the provided vector with a list of convert to LLVM attributes
+        to apply.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getConvertToLLVMConversionAttrs",
+      /*args=*/(ins
+        "::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
new file mode 100644
index 00000000000000..ad8c39fe676618
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
@@ -0,0 +1,25 @@
+//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace gpu {
+/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOp`
+/// operation so `convert-to-llvm` can use the module's target attribute.
+void registerConvertGpuToLLVMInterface(DialectRegistry &registry);
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
new file mode 100644
index 00000000000000..6311630a23c8f6
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
@@ -0,0 +1,27 @@
+//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace NVVM {
+/// Registers the `ConvertToLLVMAttrInterface` interface on the
+/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion
+/// target, LLVM type converter, and pattern set for converting GPU operations
+/// to NVVM.
+void registerConvertGpuToNVVMInterface(DialectRegistry &registry);
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 645e86a4309621..fc7c967f1b62cf 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 
+/// Configure the LLVM type converter to convert types and address spaces from
+/// the GPU dialect to NVVM.
+void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
+
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 1f2ef26b450701..14a6a2787b3a5d 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -18,6 +18,8 @@
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -72,6 +74,8 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
   registerConvertAMXToLLVMInterface(registry);
+  gpu::registerConvertGpuToLLVMInterface(registry);
+  NVVM::registerConvertGpuToNVVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
index de3d850d520c0f..a9b49391e36a70 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface
   ToLLVMInterface.cpp
 
   DEPENDS
+  MLIRConvertToLLVMInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRIR
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index b2407a258c2719..0f3e852977ada2 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -63,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
 /// the injection of conversion patterns.
 class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
-  std::shared_ptr<const FrozenRewritePatternSet> patterns;
-  std::shared_ptr<const ConversionTarget> target;
-  std::shared_ptr<const LLVMTypeConverter> typeConverter;
+  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+      interfaces;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -75,11 +74,8 @@ class ConvertToLLVMPass
   }
 
   LogicalResult initialize(MLIRContext *context) final {
-    RewritePatternSet tempPatterns(context);
-    auto target = std::make_shared<ConversionTarget>(*context);
-    target->addLegalDialect<LLVM::LLVMDialect>();
-    auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
-
+    auto interfaces =
+        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
     if (!filterDialects.empty()) {
       // Test mode: Populate only patterns from the specified dialects. Produce
       // an error if the dialect is not loaded or does not implement the
@@ -94,8 +90,7 @@ class ConvertToLLVMPass
           return emitError(UnknownLoc::get(context))
                  << "dialect does not implement ConvertToLLVMPatternInterface: "
                  << dialectName << "\n";
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     } else {
       // Normal mode: Populate all patterns from all dialects that implement the
@@ -106,20 +101,33 @@ class ConvertToLLVMPass
         auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
         if (!iface)
           continue;
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     }
 
-    this->patterns =
-        std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
-    this->target = target;
-    this->typeConverter = typeConverter;
+    this->interfaces = interfaces;
     return success();
   }
 
   void runOnOperation() final {
-    if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    LLVMTypeConverter typeConverter(context);
+
+    // Configure the conversion with dialect level interfaces.
+    for (ConvertToLLVMPatternInterface *iface : *interfaces)
+      iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                     patterns);
+
+    // Configure the conversion attribute interfaces.
+    populateOpConvertToLLVMConversionPatterns(getOperation(), target,
+                                              typeConverter, patterns);
+
+    // Apply the conversion.
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 3a4e83b2a8838f..5cc71178b0f28d 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -30,3 +30,20 @@ void mlir::populateConversionTargetFromOperation(
                                                    patterns);
   });
 }
+
+void mlir::populateOpConvertToLLVMConversionPatterns(
+    Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+  if (!iface)
+    return;
+  SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
+  iface.getConvertToLLVMConversionAttrs(attrs);
+  for (ConvertToLLVMAttrInterface attr : attrs)
+    attr.populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                 patterns);
+}
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.cpp.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.cpp.inc"
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c58737..1497d662dcdbdd 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -1762,3 +1763,34 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
 }
+
+//===----------------------------------------------------------------------===//
+// GPUModuleOp convert to LLVM op interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct GPUModuleOpConvertToLLVMInterface
+    : public ConvertToLLVMOpInterface::ExternalModel<
+          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
+  /// Collect the module's target attribute if it implements the interface.
+  void getConvertToLLVMConversionAttrs(
+      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
+};
+} // namespace
+
+void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
+    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
+  auto module = cast<gpu::GPUModuleOp>(op);
+  ArrayAttr targetsAttr = module.getTargetsAttr();
+  // Collect nothing if there are no targets or there is more than one target.
+  if (!targetsAttr || targetsAttr.size() != 1)
+    return;
+  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
+    attrs.push_back(patternAttr);
+}
+
+void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
+    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 04e85c2b337dec..b343cf71e3a2e7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,8 +15,10 @@
 
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -274,29 +276,7 @@ struct LowerGpuOpsToNVVMOpsPass
     }
 
     LLVMTypeConverter converter(m.getContext(), options);
-    // NVVM uses alloca in the default address space to represent private
-    // memory allocations, so drop private annotations. NVVM uses address
-    // space 3 for shared memory. NVVM uses the default address space to
-    // represent global memory.
-    populateGpuMemorySpaceAttributeConversions(
-        converter, [](gpu::AddressSpace space) -> unsigned {
-          switch (space) {
-          case gpu::AddressSpace::Global:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-          case gpu::AddressSpace::Workgroup:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kSharedMemorySpace);
-          case gpu::AddressSpace::Private:
-            return 0;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    // Lowering for MMAMatrixType.
-    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToLLVMType(type);
-    });
+    configureGpuToNVVMTypeConverter(converter);
     RewritePatternSet llvmPatterns(m.getContext());
 
     arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -332,6 +312,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
 }
 
+void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
+  // NVVM uses alloca in the default address space to represent private
+  // memory allocations, so drop private annotations. Global and workgroup
+  // memory are mapped to the dedicated NVVM global and shared memory
+  // address spaces, respectively.
+  populateGpuMemorySpaceAttributeConversions(
+      converter, [](gpu::AddressSpace space) -> unsigned {
+        switch (space) {
+        case gpu::AddressSpace::Global:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+        case gpu::AddressSpace::Workgroup:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+        case gpu::AddressSpace::Private:
+          return 0;
+        }
+        llvm_unreachable("unknown address space enum value");
+        return 0;
+      });
+  // Lowering for MMAMatrixType (stored in the converter: capture nothing).
+  converter.addConversion([](gpu::MMAMatrixType type) -> Type {
+    return convertMMAToLLVMType(type);
+  });
+}
+
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
@@ -467,3 +473,34 @@ void mlir::populateGpuToNVVMConversionPatterns(
   populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                    "__nv_tanh");
 }
+
+//===----------------------------------------------------------------------===//
+// NVVMTargetAttr convert to LLVM attr interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct NVVMTargetConvertToLLVMAttrInterface
+    : public ConvertToLLVMAttrInterface::ExternalModel<
+          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
+  /// Populate the GPU to NVVM conversion target, type converter and patterns.
+  void populateConvertToLLVMConversionPatterns(
+      Attribute attr, ConversionTarget &target,
+      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
+};
+} // namespace
+
+void NVVMTargetConvertToLLVMAttrInterface::
+    populateConvertToLLVMConversionPatterns(Attribute attr,
+                                            ConversionTarget &target,
+                                            LLVMTypeConverter &typeConverter,
+                                            RewritePatternSet &patterns) const {
+  configureGpuToNVVMConversionLegality(target);
+  configureGpuToNVVMTypeConverter(typeConverter);
+  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
+}
+
+void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
+    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
new file mode 100644
index 00000000000000..6e55a56d8c9aa5
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+
+// CHECK-LABEL: gpu.module @nvvm_module
+gpu.module @nvvm_module [#nvvm.target] {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %tIdX = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+  // CHECK-LABEL: llvm.func @kernel_1
+  // CHECK: (%{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i64)
+  // CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.module @nvvm_module_2
+gpu.module @nvvm_module_2 {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = gpu.thread_id x
+    %tIdX = gpu.thread_id x
+    // CHECK: = gpu.lane_id
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: (%{{.*}}: memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>}
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}

>From 97e79915d8b1d96b0c3720a1c8d5265123214bb8 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 29 Jul 2024 18:56:18 +0000
Subject: [PATCH 2/3] add option to control whether to use conversion
 attributes

---
 mlir/include/mlir/Conversion/Passes.td        |  6 +++
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 51 +++++++++++++++++--
 .../ConvertToLLVM/ToLLVMInterface.cpp         |  2 +
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    |  2 +-
 4 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f1..57d187bd6dcb32 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -22,12 +22,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
     This is a generic pass to convert to LLVM, it uses the
     `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
     the injection of conversion patterns.
+
+    If `use-conversion-attrs` is set to `true`, the pass will look for
+    `ConvertToLLVMAttrInterface` attributes and use them to further configure
+    the conversion process. Enabling this option incurs extra overhead.
   }];
 
   let constructor = "mlir::createConvertToLLVMPass()";
   let options = [
     ListOption<"filterDialects", "filter-dialects", "std::string",
                "Test conversion patterns of only the specified dialects">,
+    Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
+           "Use op conversion attributes to configure the conversion">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 0f3e852977ada2..a06de03aa5ec42 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -65,6 +65,9 @@ class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
   std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
       interfaces;
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
+  std::shared_ptr<const ConversionTarget> target;
+  std::shared_ptr<const LLVMTypeConverter> typeConverter;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -74,8 +77,22 @@ class ConvertToLLVMPass
   }
 
   LogicalResult initialize(MLIRContext *context) final {
-    auto interfaces =
-        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+    std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
+    std::shared_ptr<ConversionTarget> target;
+    std::shared_ptr<LLVMTypeConverter> typeConverter;
+    RewritePatternSet tempPatterns(context);
+
+    // Only collect the interfaces if `useConversionAttrs=true` as everything
+    // else must be initialized in `runOnOperation`.
+    if (useConversionAttrs) {
+      interfaces =
+          std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+    } else {
+      target = std::make_shared<ConversionTarget>(*context);
+      target->addLegalDialect<LLVM::LLVMDialect>();
+      typeConverter = std::make_shared<LLVMTypeConverter>(context);
+    }
+
     if (!filterDialects.empty()) {
       // Test mode: Populate only patterns from the specified dialects. Produce
       // an error if the dialect is not loaded or does not implement the
@@ -90,7 +107,12 @@ class ConvertToLLVMPass
           return emitError(UnknownLoc::get(context))
                  << "dialect does not implement ConvertToLLVMPatternInterface: "
                  << dialectName << "\n";
-        interfaces->push_back(iface);
+        if (useConversionAttrs) {
+          interfaces->push_back(iface);
+          continue;
+        }
+        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+                                                       tempPatterns);
       }
     } else {
       // Normal mode: Populate all patterns from all dialects that implement the
@@ -101,15 +123,34 @@ class ConvertToLLVMPass
         auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
         if (!iface)
           continue;
-        interfaces->push_back(iface);
+        if (useConversionAttrs) {
+          interfaces->push_back(iface);
+          continue;
+        }
+        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+                                                       tempPatterns);
       }
     }
 
-    this->interfaces = interfaces;
+    if (useConversionAttrs) {
+      this->interfaces = interfaces;
+    } else {
+      this->patterns =
+          std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
+      this->target = target;
+      this->typeConverter = typeConverter;
+    }
     return success();
   }
 
   void runOnOperation() final {
+    // Fast path:
+    if (!useConversionAttrs) {
+      if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+        signalPassFailure();
+      return;
+    }
+    // Slow path with conversion attributes.
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 5cc71178b0f28d..252245dfbf5417 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -35,6 +35,8 @@ void mlir::populateOpConvertToLLVMConversionPatterns(
     Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
     RewritePatternSet &patterns) {
   auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+  if (!iface)
+    iface = op->getParentOfType<ConvertToLLVMOpInterface>();
   if (!iface)
     return;
   SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index 6e55a56d8c9aa5..132eb473cb9189 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
 
 // CHECK-LABEL: gpu.module @nvvm_module
 gpu.module @nvvm_module [#nvvm.target] {

>From f550639aa18f7d3cf0ea0fe30733925da394433a Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 23 Nov 2024 22:53:26 +0000
Subject: [PATCH 3/3] address reviewer comments

---
 .../ConvertToLLVM/ToLLVMInterface.h           |  37 +++
 mlir/include/mlir/Conversion/Passes.td        |   8 +-
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 243 +++++++++++-------
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    |   4 +-
 4 files changed, 198 insertions(+), 94 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 1d14ff30eb5516..af8150d1d329e2 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -19,6 +19,7 @@ class LLVMTypeConverter;
 class MLIRContext;
 class Operation;
 class RewritePatternSet;
+class AnalysisManager;
 
 /// Base class for dialect interfaces providing translation to LLVM IR.
 /// Dialects that can be translated should provide an implementation of this
@@ -59,6 +60,42 @@ void populateOpConvertToLLVMConversionPatterns(Operation *op,
                                                ConversionTarget &target,
                                                LLVMTypeConverter &typeConverter,
                                                RewritePatternSet &patterns);
+
+/// Base class for the internal implementations of the `convert-to-llvm`
+/// pass.
+class ConvertToLLVMPassInterface {
+public:
+  ConvertToLLVMPassInterface(MLIRContext *context,
+                             ArrayRef<std::string> filterDialects);
+  virtual ~ConvertToLLVMPassInterface() = default;
+
+  /// Add the dialects and extensions that `convert-to-llvm` depends on.
+  static void getDependentDialects(DialectRegistry &registry);
+
+  /// Initialize the internal state of the `convert-to-llvm` pass
+  /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
+  /// Returns failure if the initialization process failed.
+  virtual LogicalResult initialize() = 0;
+
+  /// Transform `op` to LLVM with the conversions available in the pass. The
+  /// analysis manager can be used to query analyses like `DataLayoutAnalysis`
+  /// to further configure the conversion process. This method is invoked by
+  /// `ConvertToLLVMPass::runOnOperation`. Returns failure if the
+  /// transformation process failed.
+  virtual LogicalResult transform(Operation *op,
+                                  AnalysisManager manager) const = 0;
+
+protected:
+  /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
+  /// `visitor` on each. If `filterDialects` is non-empty, only the listed
+  /// dialects are visited, and failure is returned if one of them is not
+  /// loaded or does not implement the interface.
+  LogicalResult visitInterfaces(
+      llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
+  MLIRContext *context;
+  /// Names of the dialects to use as filters (non-owning view).
+  ArrayRef<std::string> filterDialects;
+};
 } // namespace mlir
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 57d187bd6dcb32..e394bae64e0918 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -23,16 +23,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
     `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
     the injection of conversion patterns.
 
-    If `use-conversion-attrs` is set to `true`, the pass will look for
+    If `dynamic` is set to `true`, the pass will look for
     `ConvertToLLVMAttrInterface` attributes and use them to further configure
-    the conversion process. Enabling this option incurs in extra overhead.
+    the conversion process. This option also uses the `DataLayoutAnalysis`
+    analysis to configure the type converter. Enabling this option incurs
+    extra overhead.
   }];
 
   let constructor = "mlir::createConvertToLLVMPass()";
   let options = [
     ListOption<"filterDialects", "filter-dialects", "std::string",
                "Test conversion patterns of only the specified dialects">,
-    Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
+    Option<"useDynamic", "dynamic", "bool", "false",
            "Use op conversion attributes to configure the conversion">,
   ];
 }
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index a06de03aa5ec42..f9b88d80695aca 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -27,7 +28,6 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-
 /// This DialectExtension can be attached to the context, which will invoke the
 /// `apply()` method for every loaded dialect. If a dialect implements the
 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
@@ -58,104 +58,82 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
   }
 };
 
-/// This is a generic pass to convert to LLVM, it uses the
-/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
-/// the injection of conversion patterns.
-class ConvertToLLVMPass
-    : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
-  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
-      interfaces;
+//===----------------------------------------------------------------------===//
+// StaticConvertToLLVM
+//===----------------------------------------------------------------------===//
+
+/// Static implementation of the `convert-to-llvm` pass. This version only looks
+/// at dialect interfaces; the conversion state is built once in `initialize`.
+struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
+  /// Pattern set with conversions to LLVM.
   std::shared_ptr<const FrozenRewritePatternSet> patterns;
+  /// The conversion target.
   std::shared_ptr<const ConversionTarget> target;
+  /// The LLVM type converter.
   std::shared_ptr<const LLVMTypeConverter> typeConverter;
+  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
 
-public:
-  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
-  void getDependentDialects(DialectRegistry &registry) const final {
-    registry.insert<LLVM::LLVMDialect>();
-    registry.addExtensions<LoadDependentDialectExtension>();
+  /// Configure the conversion to LLVM at pass initialization.
+  LogicalResult initialize() final {
+    auto target = std::make_shared<ConversionTarget>(*context);
+    auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
+    RewritePatternSet tempPatterns(context);
+    target->addLegalDialect<LLVM::LLVMDialect>();
+    // Let each visited dialect interface populate the patterns.
+    if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
+          iface->populateConvertToLLVMConversionPatterns(
+              *target, *typeConverter, tempPatterns);
+        })))
+      return failure();
+    this->patterns =
+        std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
+    this->target = target;
+    this->typeConverter = typeConverter;
+    return success();
   }
 
-  LogicalResult initialize(MLIRContext *context) final {
-    std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
-    std::shared_ptr<ConversionTarget> target;
-    std::shared_ptr<LLVMTypeConverter> typeConverter;
-    RewritePatternSet tempPatterns(context);
+  /// Apply the pre-built patterns with the partial conversion driver.
+  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
+    if (failed(applyPartialConversion(op, *target, *patterns)))
+      return failure();
+    return success();
+  }
+};
 
-    // Only collect the interfaces if `useConversionAttrs=true` as everything
-    // else must be initialized in `runOnOperation`.
-    if (useConversionAttrs) {
-      interfaces =
-          std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
-    } else {
-      target = std::make_shared<ConversionTarget>(*context);
-      target->addLegalDialect<LLVM::LLVMDialect>();
-      typeConverter = std::make_shared<LLVMTypeConverter>(context);
-    }
+//===----------------------------------------------------------------------===//
+// DynamicConvertToLLVM
+//===----------------------------------------------------------------------===//
 
-    if (!filterDialects.empty()) {
-      // Test mode: Populate only patterns from the specified dialects. Produce
-      // an error if the dialect is not loaded or does not implement the
-      // interface.
-      for (std::string &dialectName : filterDialects) {
-        Dialect *dialect = context->getLoadedDialect(dialectName);
-        if (!dialect)
-          return emitError(UnknownLoc::get(context))
-                 << "dialect not loaded: " << dialectName << "\n";
-        auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
-        if (!iface)
-          return emitError(UnknownLoc::get(context))
-                 << "dialect does not implement ConvertToLLVMPatternInterface: "
-                 << dialectName << "\n";
-        if (useConversionAttrs) {
-          interfaces->push_back(iface);
-          continue;
-        }
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
-      }
-    } else {
-      // Normal mode: Populate all patterns from all dialects that implement the
-      // interface.
-      for (Dialect *dialect : context->getLoadedDialects()) {
-        // First time we encounter this dialect: if it implements the interface,
-        // let's populate patterns !
-        auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
-        if (!iface)
-          continue;
-        if (useConversionAttrs) {
+/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
+/// the IR to configure the conversion to LLVM.
+struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
+  /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
+  /// to partially configure the conversion process.
+  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+      interfaces;
+  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
+
+  /// Collect the dialect interfaces used to configure the conversion process.
+  LogicalResult initialize() final {
+    auto interfaces =
+        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+    // Collect the interfaces.
+    if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
           interfaces->push_back(iface);
-          continue;
-        }
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
-      }
-    }
-
-    if (useConversionAttrs) {
-      this->interfaces = interfaces;
-    } else {
-      this->patterns =
-          std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
-      this->target = target;
-      this->typeConverter = typeConverter;
-    }
+        })))
+      return failure();
+    this->interfaces = interfaces;
     return success();
   }
 
-  void runOnOperation() final {
-    // Fast path:
-    if (!useConversionAttrs) {
-      if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
-        signalPassFailure();
-      return;
-    }
-    // Slow path with conversion attributes.
-    MLIRContext *context = &getContext();
+  /// Build the conversion state for `op` and apply the conversion driver.
+  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
     target.addLegalDialect<LLVM::LLVMDialect>();
-    LLVMTypeConverter typeConverter(context);
+    // Configure the type converter with the `DataLayoutAnalysis`.
+    const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
+    LLVMTypeConverter typeConverter(context, &dlAnalysis);
 
     // Configure the conversion with dialect level interfaces.
     for (ConvertToLLVMPatternInterface *iface : *interfaces)
@@ -163,18 +141,105 @@ class ConvertToLLVMPass
                                                      patterns);
 
     // Configure the conversion attribute interfaces.
-    populateOpConvertToLLVMConversionPatterns(getOperation(), target,
-                                              typeConverter, patterns);
+    populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
+                                              patterns);
 
     // Apply the conversion.
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      return failure();
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPass
+//===----------------------------------------------------------------------===//
+
+/// Generic pass converting to LLVM. It uses `ConvertToLLVMPatternInterface`
+/// dialect interfaces to let dialects inject conversion patterns, delegating
+/// to `StaticConvertToLLVM` or `DynamicConvertToLLVM` (`dynamic` option).
+class ConvertToLLVMPass
+    : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
+  std::shared_ptr<const ConvertToLLVMPassInterface> impl;
+
+public:
+  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
+  void getDependentDialects(DialectRegistry &registry) const final {
+    ConvertToLLVMPassInterface::getDependentDialects(registry);
+  }
+
+  LogicalResult initialize(MLIRContext *context) final {
+    std::shared_ptr<ConvertToLLVMPassInterface> impl;
+    // Choose the implementation selected by the `dynamic` option.
+    if (useDynamic)
+      impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+    else
+      impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+    if (failed(impl->initialize()))
+      return failure();
+    this->impl = impl;
+    return success();
+  }
+
+  void runOnOperation() final {
+    if (failed(impl->transform(getOperation(), getAnalysisManager())))
+      return signalPassFailure();
   }
 };
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPassInterface
+//===----------------------------------------------------------------------===//
+
+ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
+    MLIRContext *context, ArrayRef<std::string> filterDialects)
+    : context(context), filterDialects(filterDialects) {}
+
+void ConvertToLLVMPassInterface::getDependentDialects(
+    DialectRegistry &registry) {
+  registry.insert<LLVM::LLVMDialect>();
+  registry.addExtensions<LoadDependentDialectExtension>();
+}
+
+LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
+    llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
+  if (!filterDialects.empty()) {
+    // Test mode: visit only the specified dialects. Emit an error and return
+    // failure if a dialect is not loaded or does not implement the
+    // interface.
+    for (StringRef dialectName : filterDialects) {
+      Dialect *dialect = context->getLoadedDialect(dialectName);
+      if (!dialect)
+        return emitError(UnknownLoc::get(context))
+               << "dialect not loaded: " << dialectName << "\n";
+      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+      if (!iface)
+        return emitError(UnknownLoc::get(context))
+               << "dialect does not implement ConvertToLLVMPatternInterface: "
+               << dialectName << "\n";
+      visitor(iface);
+    }
+  } else {
+    // Normal mode: visit every loaded dialect that implements the
+    // interface.
+    for (Dialect *dialect : context->getLoadedDialects()) {
+      // Skip dialects that do not implement the interface; pass the others
+      // to `visitor`.
+      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+      if (!iface)
+        continue;
+      visitor(iface);
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
 void mlir::registerConvertToLLVMDependentDialectLoading(
     DialectRegistry &registry) {
   registry.addExtensions<LoadDependentDialectExtension>();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index 132eb473cb9189..ed7fa6508d5ade 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s
 
 // CHECK-LABEL: gpu.module @nvvm_module
 gpu.module @nvvm_module [#nvvm.target] {
@@ -7,7 +7,7 @@ gpu.module @nvvm_module [#nvvm.target] {
     // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %tIdX = gpu.thread_id x
-    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %laneId = gpu.lane_id
     %sum = index.add %tIdX, %laneId



More information about the Mlir-commits mailing list