[Mlir-commits] [mlir] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces (PR #99566)

Fabian Mora llvmlistbot at llvm.org
Mon Jul 29 11:58:14 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/99566

>From 78c7b401ebce2bc044e9407d39e1a81e631225b6 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 25 Jul 2024 20:53:36 +0000
Subject: [PATCH 1/2] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and
 `ConvertToLLVMOpInterface` interfaces

This patch adds the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available.

The `ConvertToLLVMAttrInterface` interface allows attributes to configure conversion to LLVM, including the conversion target, LLVM type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure conversion to LLVM.

The `ConvertToLLVMOpInterface` interface collects all convert to LLVM attributes stored in an operation.

Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
---
 mlir/include/mlir/Conversion/CMakeLists.txt   |  2 +
 .../Conversion/ConvertToLLVM/CMakeLists.txt   |  7 ++
 .../ConvertToLLVM/ToLLVMInterface.h           | 13 +++
 .../ConvertToLLVM/ToLLVMInterface.td          | 76 +++++++++++++++++
 .../mlir/Conversion/GPUCommon/GPUToLLVM.h     | 25 ++++++
 .../mlir/Conversion/GPUToNVVM/GPUToNVVM.h     | 27 ++++++
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |  4 +
 mlir/include/mlir/InitAllExtensions.h         |  4 +
 .../Conversion/ConvertToLLVM/CMakeLists.txt   |  1 +
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 42 ++++++----
 .../ConvertToLLVM/ToLLVMInterface.cpp         | 17 ++++
 .../GPUCommon/GPUToLLVMConversion.cpp         | 32 +++++++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 83 ++++++++++++++-----
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    | 42 ++++++++++
 14 files changed, 335 insertions(+), 40 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
 create mode 100644 mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
 create mode 100644 mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
 create mode 100644 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt
index d212bf3e395e7..9f76ab659215e 100644
--- a/mlir/include/mlir/Conversion/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/CMakeLists.txt
@@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
 add_public_tablegen_target(MLIRConversionPassIncGen)
 
 add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
+
+add_subdirectory(ConvertToLLVM)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..54d7a03fc22df
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -0,0 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td)
+mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 00aeed9bf29dc..1d14ff30eb551 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
 class ConversionTarget;
@@ -50,6 +51,18 @@ void populateConversionTargetFromOperation(Operation *op,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns);
 
+/// Helper function for populating LLVM conversion patterns. If `op` implements
+/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern
+/// attributes provided by the interface will be used to configure the
+/// conversion target, type converter, and the pattern set.
+void populateOpConvertToLLVMConversionPatterns(Operation *op,
+                                               ConversionTarget &target,
+                                               LLVMTypeConverter &typeConverter,
+                                               RewritePatternSet &patterns);
 } // namespace mlir
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc"
+
 #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
new file mode 100644
index 0000000000000..1331a9802c570
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
@@ -0,0 +1,76 @@
+
+//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines interfaces for managing transformations, including populating
+// pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMAttrInterface :
+    AttrInterface<"ConvertToLLVMAttrInterface"> {
+  let description = [{
+    The `ConvertToLLVMAttrInterface` attribute interface allows using
+    attributes to configure the convert to LLVM infrastructure, this includes:
+     - The conversion target.
+     - The LLVM type converter.
+     - The pattern set.
+
+    This interface permits fine-grained configuration of the `convert-to-llvm`
+    process. For example, attributes with target information like
+    `#nvvm.target` or `#rocdl.target` can leverage this interface for populating
+    patterns specific to a particular target.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the dialect conversion target, type converter and pattern set.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"populateConvertToLLVMConversionPatterns",
+      /*args=*/(ins "::mlir::ConversionTarget&":$target,
+                    "::mlir::LLVMTypeConverter&":$typeConverter,
+                    "::mlir::RewritePatternSet&":$patternSet)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Op interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> {
+  let description = [{
+    Interface for collecting all convert to LLVM attributes stored in an
+    operation. See `ConvertToLLVMAttrInterface` for more information on these
+    attributes.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the provided vector with a list of convert to LLVM attributes
+        to apply.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getConvertToLLVMConversionAttrs",
+      /*args=*/(ins
+        "::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
new file mode 100644
index 0000000000000..ad8c39fe67661
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
@@ -0,0 +1,25 @@
+//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace gpu {
+/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOp`
+/// operation.
+void registerConvertGpuToLLVMInterface(DialectRegistry &registry);
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
new file mode 100644
index 0000000000000..6311630a23c8f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
@@ -0,0 +1,27 @@
+//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace NVVM {
+/// Registers the `ConvertToLLVMAttrInterface` interface on the
+/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion
+/// target, LLVM type converter, and pattern set for converting GPU operations
+/// to NVVM.
+void registerConvertGpuToNVVMInterface(DialectRegistry &registry);
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index e0f4c71051e50..61741a2678c7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 
+/// Configure the LLVM type converter to convert types and address spaces from
+/// the GPU dialect to NVVM.
+void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
+
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..e7c7616f93c4b 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -18,6 +18,8 @@
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -65,6 +67,8 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  gpu::registerConvertGpuToLLVMInterface(registry);
+  NVVM::registerConvertGpuToNVVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
index df7e3f995303c..4dca9bb3869fb 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface
   ToLLVMInterface.cpp
 
   DEPENDS
+  MLIRConvertToLLVMInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRIR
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 6135117348a5b..d2d01fb3d80d9 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -61,9 +61,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
 /// the injection of conversion patterns.
 class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
-  std::shared_ptr<const FrozenRewritePatternSet> patterns;
-  std::shared_ptr<const ConversionTarget> target;
-  std::shared_ptr<const LLVMTypeConverter> typeConverter;
+  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+      interfaces;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -73,11 +72,8 @@ class ConvertToLLVMPass
   }
 
   LogicalResult initialize(MLIRContext *context) final {
-    RewritePatternSet tempPatterns(context);
-    auto target = std::make_shared<ConversionTarget>(*context);
-    target->addLegalDialect<LLVM::LLVMDialect>();
-    auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
-
+    auto interfaces =
+        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
     if (!filterDialects.empty()) {
       // Test mode: Populate only patterns from the specified dialects. Produce
       // an error if the dialect is not loaded or does not implement the
@@ -92,8 +88,7 @@ class ConvertToLLVMPass
           return emitError(UnknownLoc::get(context))
                  << "dialect does not implement ConvertToLLVMPatternInterface: "
                  << dialectName << "\n";
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     } else {
       // Normal mode: Populate all patterns from all dialects that implement the
@@ -104,20 +99,33 @@ class ConvertToLLVMPass
         auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
         if (!iface)
           continue;
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     }
 
-    this->patterns =
-        std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
-    this->target = target;
-    this->typeConverter = typeConverter;
+    this->interfaces = interfaces;
     return success();
   }
 
   void runOnOperation() final {
-    if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    LLVMTypeConverter typeConverter(context);
+
+    // Configure the conversion with dialect level interfaces.
+    for (ConvertToLLVMPatternInterface *iface : *interfaces)
+      iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                     patterns);
+
+    // Configure the conversion with attribute interfaces.
+    populateOpConvertToLLVMConversionPatterns(getOperation(), target,
+                                              typeConverter, patterns);
+
+    // Apply the conversion.
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 3a4e83b2a8838..5cc71178b0f28 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -30,3 +30,20 @@ void mlir::populateConversionTargetFromOperation(
                                                    patterns);
   });
 }
+
+void mlir::populateOpConvertToLLVMConversionPatterns(
+    Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+  if (!iface)
+    return;
+  SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
+  iface.getConvertToLLVMConversionAttrs(attrs);
+  for (ConvertToLLVMAttrInterface attr : attrs)
+    attr.populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                 patterns);
+}
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.cpp.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.cpp.inc"
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c5873..1497d662dcdbd 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -1762,3 +1763,34 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
 }
+
+//===----------------------------------------------------------------------===//
+// GPUModuleOp convert to LLVM op interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct GPUModuleOpConvertToLLVMInterface
+    : public ConvertToLLVMOpInterface::ExternalModel<
+          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
+  /// Get the conversion patterns from the target attribute.
+  void getConvertToLLVMConversionAttrs(
+      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
+};
+} // namespace
+
+void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
+    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
+  auto module = cast<gpu::GPUModuleOp>(op);
+  ArrayAttr targetsAttr = module.getTargetsAttr();
+  // Add nothing unless the module has exactly one target attribute.
+  if (!targetsAttr || targetsAttr.size() != 1)
+    return;
+  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
+    attrs.push_back(patternAttr);
+}
+
+void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
+    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index faa97caacb885..7d7c5da096a68 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,8 +15,10 @@
 
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -261,29 +263,7 @@ struct LowerGpuOpsToNVVMOpsPass
     }
 
     LLVMTypeConverter converter(m.getContext(), options);
-    // NVVM uses alloca in the default address space to represent private
-    // memory allocations, so drop private annotations. NVVM uses address
-    // space 3 for shared memory. NVVM uses the default address space to
-    // represent global memory.
-    populateGpuMemorySpaceAttributeConversions(
-        converter, [](gpu::AddressSpace space) -> unsigned {
-          switch (space) {
-          case gpu::AddressSpace::Global:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-          case gpu::AddressSpace::Workgroup:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kSharedMemorySpace);
-          case gpu::AddressSpace::Private:
-            return 0;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    // Lowering for MMAMatrixType.
-    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToLLVMType(type);
-    });
+    configureGpuToNVVMTypeConverter(converter);
     RewritePatternSet llvmPatterns(m.getContext());
 
     arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -319,6 +299,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
 
+void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
+  // NVVM uses alloca in the default address space to represent private
+  // memory allocations, so drop private annotations. NVVM uses address
+  // space 3 for shared memory. NVVM uses the default address space to
+  // represent global memory.
+  populateGpuMemorySpaceAttributeConversions(
+      converter, [](gpu::AddressSpace space) -> unsigned {
+        switch (space) {
+        case gpu::AddressSpace::Global:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+        case gpu::AddressSpace::Workgroup:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+        case gpu::AddressSpace::Private:
+          return 0;
+        }
+        llvm_unreachable("unknown address space enum value");
+        return 0;
+      });
+  // Lowering for MMAMatrixType.
+  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+    return convertMMAToLLVMType(type);
+  });
+}
+
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
@@ -438,3 +444,34 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
   populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                    "__nv_tanh");
 }
+
+//===----------------------------------------------------------------------===//
+// NVVMTargetAttr convert to LLVM attr interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct NVVMTargetConvertToLLVMAttrInterface
+    : public ConvertToLLVMAttrInterface::ExternalModel<
+          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
+  /// Configure GPU to NVVM.
+  void populateConvertToLLVMConversionPatterns(
+      Attribute attr, ConversionTarget &target,
+      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
+};
+} // namespace
+
+void NVVMTargetConvertToLLVMAttrInterface::
+    populateConvertToLLVMConversionPatterns(Attribute attr,
+                                            ConversionTarget &target,
+                                            LLVMTypeConverter &typeConverter,
+                                            RewritePatternSet &patterns) const {
+  configureGpuToNVVMConversionLegality(target);
+  configureGpuToNVVMTypeConverter(typeConverter);
+  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
+}
+
+void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
+    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
new file mode 100644
index 0000000000000..6e55a56d8c9aa
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+
+// CHECK-LABEL: gpu.module @nvvm_module
+gpu.module @nvvm_module [#nvvm.target] {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %tIdX = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+// CHECK-LABEL: llvm.func @kernel_1
+// CHECK: (%{{.*}}: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64)
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>} 
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.module @nvvm_module_2
+gpu.module @nvvm_module_2 {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = gpu.thread_id x
+    %tIdX = gpu.thread_id x
+    // CHECK: = gpu.lane_id
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: (%{{.*}}: memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>}
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}

>From ac29bbbea52a495f0d9a34be633e885d00e33e19 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 29 Jul 2024 18:56:18 +0000
Subject: [PATCH 2/2] add option to control whether to use conversion
 attributes

---
 mlir/include/mlir/Conversion/Passes.td        |  6 +++
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 51 +++++++++++++++++--
 .../ConvertToLLVM/ToLLVMInterface.cpp         |  2 +
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    |  2 +-
 4 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..d1ff185d3d9b2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -22,12 +22,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
     This is a generic pass to convert to LLVM, it uses the
     `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
     the injection of conversion patterns.
+
+    If `use-conversion-attrs` is set to `true`, the pass will look for
+    `ConvertToLLVMAttrInterface` attributes and use them to further configure
+    the conversion process. Enabling this option incurs extra overhead.
   }];
 
   let constructor = "mlir::createConvertToLLVMPass()";
   let options = [
     ListOption<"filterDialects", "filter-dialects", "std::string",
                "Test conversion patterns of only the specified dialects">,
+    Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
+           "Use op conversion attributes to configure the conversion">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index d2d01fb3d80d9..2068b54182aba 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -63,6 +63,9 @@ class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
   std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
       interfaces;
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
+  std::shared_ptr<const ConversionTarget> target;
+  std::shared_ptr<const LLVMTypeConverter> typeConverter;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -72,8 +75,22 @@ class ConvertToLLVMPass
   }
 
   LogicalResult initialize(MLIRContext *context) final {
-    auto interfaces =
-        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+    std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
+    std::shared_ptr<ConversionTarget> target;
+    std::shared_ptr<LLVMTypeConverter> typeConverter;
+    RewritePatternSet tempPatterns(context);
+
+    // Only collect the interfaces if `useConversionAttrs=true` as everything
+    // else must be initialized in `runOnOperation`.
+    if (useConversionAttrs) {
+      interfaces =
+          std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+    } else {
+      target = std::make_shared<ConversionTarget>(*context);
+      target->addLegalDialect<LLVM::LLVMDialect>();
+      typeConverter = std::make_shared<LLVMTypeConverter>(context);
+    }
+
     if (!filterDialects.empty()) {
       // Test mode: Populate only patterns from the specified dialects. Produce
       // an error if the dialect is not loaded or does not implement the
@@ -88,7 +105,12 @@ class ConvertToLLVMPass
           return emitError(UnknownLoc::get(context))
                  << "dialect does not implement ConvertToLLVMPatternInterface: "
                  << dialectName << "\n";
-        interfaces->push_back(iface);
+        if (useConversionAttrs) {
+          interfaces->push_back(iface);
+          continue;
+        }
+        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+                                                       tempPatterns);
       }
     } else {
       // Normal mode: Populate all patterns from all dialects that implement the
@@ -99,15 +121,34 @@ class ConvertToLLVMPass
         auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
         if (!iface)
           continue;
-        interfaces->push_back(iface);
+        if (useConversionAttrs) {
+          interfaces->push_back(iface);
+          continue;
+        }
+        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+                                                       tempPatterns);
       }
     }
 
-    this->interfaces = interfaces;
+    if (useConversionAttrs) {
+      this->interfaces = interfaces;
+    } else {
+      this->patterns =
+          std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
+      this->target = target;
+      this->typeConverter = typeConverter;
+    }
     return success();
   }
 
   void runOnOperation() final {
+    // Fast path:
+    if (!useConversionAttrs) {
+      if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+        signalPassFailure();
+      return;
+    }
+    // Slow path with conversion attributes.
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 5cc71178b0f28..252245dfbf541 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -35,6 +35,8 @@ void mlir::populateOpConvertToLLVMConversionPatterns(
     Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
     RewritePatternSet &patterns) {
   auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+  if (!iface)
+    iface = op->getParentOfType<ConvertToLLVMOpInterface>();
   if (!iface)
     return;
   SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index 6e55a56d8c9aa..132eb473cb918 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
 
 // CHECK-LABEL: gpu.module @nvvm_module
 gpu.module @nvvm_module [#nvvm.target] {



More information about the Mlir-commits mailing list