[Mlir-commits] [mlir] [mlir] Add the `TransformsInterfaces` for configuring transformations (PR #99566)

Fabian Mora llvmlistbot at llvm.org
Thu Jul 18 13:58:52 PDT 2024


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/99566

This patch adds the `ConversionPatternsAttrInterface` and `OpWithTransformAttrsOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available.

The `ConversionPatternsAttrInterface` allows attributes to configure the dialect conversion infrastructure, including the conversion target, the type converter, and the conversion pattern set. See the `NVVMTargetAttr` implementation of this interface for an example of how it can be used to configure dialect conversion.
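
To make the options-passing mechanism concrete, here is a self-contained C++ sketch that models it outside of MLIR. It mirrors the shape of `ConversionPatternAttrOptions`, `LLVMConversionPatternAttrOptions` and the NVVM external model in this patch, but every type in it (`ConversionTarget`, `LLVMTypeConverter`, the `typeIdOf` and `dynCast` helpers, `FakeNVVMTargetAttr`) is a hypothetical stand-in rather than the real MLIR API, and patterns are modeled as plain strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::TypeID: the address of a function-local
// static is unique per instantiated type.
using TypeID = const void *;
template <typename T> TypeID typeIdOf() {
  static char id;
  return &id;
}

// Hypothetical stand-ins for ConversionTarget and the type converters.
struct ConversionTarget {
  std::vector<std::string> legalDialects;
};
struct TypeConverter {
  virtual ~TypeConverter() = default;
};
struct LLVMTypeConverter : TypeConverter {};

// Models ConversionPatternAttrOptions: references plus a TypeID tag that
// identifies the most-derived options class.
class ConversionPatternAttrOptions {
public:
  ConversionPatternAttrOptions(ConversionTarget &target,
                               TypeConverter &converter)
      : ConversionPatternAttrOptions(typeIdOf<ConversionPatternAttrOptions>(),
                                     target, converter) {}
  TypeID getTypeID() const { return typeID; }
  ConversionTarget &getConversionTarget() { return target; }
  TypeConverter &getTypeConverter() { return converter; }

protected:
  // Derived classes pass their own TypeID through this constructor.
  ConversionPatternAttrOptions(TypeID typeID, ConversionTarget &target,
                               TypeConverter &converter)
      : target(target), converter(converter), typeID(typeID) {}
  ConversionTarget &target;
  TypeConverter &converter;

private:
  TypeID typeID;
};

// Models LLVMConversionPatternAttrOptions.
class LLVMConversionPatternAttrOptions : public ConversionPatternAttrOptions {
public:
  LLVMConversionPatternAttrOptions(ConversionTarget &target,
                                   LLVMTypeConverter &converter)
      : ConversionPatternAttrOptions(
            typeIdOf<LLVMConversionPatternAttrOptions>(), target, converter) {}

  // LLVM-style RTTI hook: true only for options built by this class.
  static bool classof(const ConversionPatternAttrOptions *opts) {
    return opts->getTypeID() == typeIdOf<LLVMConversionPatternAttrOptions>();
  }

  // Safe because the constructor only accepts an LLVMTypeConverter.
  LLVMTypeConverter &getLLVMTypeConverter() {
    return static_cast<LLVMTypeConverter &>(converter);
  }
};

// Minimal dyn_cast analogue: checked downcast via the target's classof.
template <typename To, typename From> To *dynCast(From *from) {
  return To::classof(from) ? static_cast<To *>(from) : nullptr;
}

// Hypothetical attribute shaped like NVVMTargetPatternsAttrInterface.
struct FakeNVVMTargetAttr {
  void populateConversionPatterns(ConversionPatternAttrOptions &options,
                                  std::vector<std::string> &patterns) const {
    auto *llvmOptions = dynCast<LLVMConversionPatternAttrOptions>(&options);
    if (!llvmOptions)
      return; // Not LLVM options: bail out, as the NVVM model does.
    options.getConversionTarget().legalDialects.push_back("nvvm");
    patterns.push_back("gpu-to-nvvm");
  }
};
```

The design point this illustrates is that the interface method only sees the base options class; an attribute that needs LLVM-specific state recovers it with a checked downcast and does nothing when handed a different kind of options.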

The `OpWithTransformAttrsOpInterface` allows operations to expose transforms attributes. These attributes allow configuring transformations, such as dialect conversion, with information present in the IR.
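
The op-side half can be modeled the same way. The sketch below assumes a hypothetical `FakeGPUModuleOp` whose `getConversionPatternAttrs` follows the logic this patch gives `gpu.module` (exactly one target, and only if that target implements the patterns interface), plus a simplified `populateOpConversionPatterns`. Attributes, ops and patterns are plain C++ stand-ins, not MLIR types.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical attribute: it "implements" the patterns interface iff
// `populate` is non-empty.
struct Attr {
  std::string name;
  std::function<void(std::vector<std::string> &)> populate;
  bool implementsPatternsInterface() const {
    return static_cast<bool>(populate);
  }
};

// Hypothetical stand-in for OpWithTransformAttrsOpInterface.
struct OpWithTransformAttrs {
  virtual ~OpWithTransformAttrs() = default;
  virtual void
  getConversionPatternAttrs(std::vector<const Attr *> &attrs) const = 0;
};

// Models the gpu.module implementation in this patch.
struct FakeGPUModuleOp : OpWithTransformAttrs {
  std::vector<Attr> targets;
  void getConversionPatternAttrs(
      std::vector<const Attr *> &attrs) const override {
    if (targets.size() != 1)
      return; // No targets or more than one: contribute nothing.
    if (targets[0].implementsPatternsInterface())
      attrs.push_back(&targets[0]);
  }
};

// Models populateOpConversionPatterns. A null `op` stands in for an
// operation that does not implement the interface (the real code uses
// dyn_cast for this check).
void populateOpConversionPatterns(const OpWithTransformAttrs *op,
                                  std::vector<std::string> &patterns) {
  if (!op)
    return;
  std::vector<const Attr *> attrs;
  op->getConversionPatternAttrs(attrs);
  for (const Attr *attr : attrs)
    attr->populate(patterns);
}
```

Keeping the selection logic in the op (rather than in the pass) means each operation decides which of its attributes are allowed to drive the conversion.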

Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This makes it possible to apply `convert-to-llvm` to GPU modules and let the `NVVMTargetAttr` decide which patterns to populate.
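
The pass change amounts to moving state construction from `initialize` to `runOnOperation`: the dialect interfaces are still collected once, but the conversion target, type converter and pattern set are rebuilt for each operation so that attributes on that operation can contribute. A rough sketch of that control flow, using hypothetical stand-in types (`ConversionState`, `FakeOp`, `FakeConvertToLLVMPass`) and omitting the actual `applyPartialConversion` call:

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the target, type converter and pattern set.
struct ConversionState {
  std::vector<std::string> legalDialects;
  std::vector<std::string> patterns;
};

// A ConvertToLLVMPatternInterface or a transforms attribute, modeled as a
// callback that configures the conversion state.
using Configurer = std::function<void(ConversionState &)>;

// Hypothetical op carrying attribute-driven configuration (think of a
// gpu.module with an NVVM target).
struct FakeOp {
  std::vector<Configurer> attrConfigs;
};

class FakeConvertToLLVMPass {
  // Collected once, like the dialect interfaces cached in initialize().
  std::vector<Configurer> interfaces;

public:
  void initialize(std::vector<Configurer> ifaces) {
    interfaces = std::move(ifaces);
  }

  // Rebuilds the conversion state for every operation, so attributes on
  // `op` can add to it; the real pass then calls applyPartialConversion.
  ConversionState runOnOperation(const FakeOp &op) const {
    ConversionState state;
    state.legalDialects.push_back("llvm");
    for (const Configurer &iface : interfaces)
      iface(state);
    for (const Configurer &config : op.attrConfigs)
      config(state);
    return state;
  }
};
```

One consequence visible in the diff is that patterns are no longer frozen once per pass instance in `initialize`, so each run repopulates them.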

From ece4a8f5aec3eeba30d8f9d7538e0bdeb78c521f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 18 Jul 2024 20:13:28 +0000
Subject: [PATCH] [mlir] Add the `TransformsInterfaces` for configuring
 transformations

This patch adds the `ConversionPatternsAttrInterface` and
`OpWithTransformAttrsOpInterface` interfaces. It also modifies the
`convert-to-llvm` pass to use these interfaces when available.

The `ConversionPatternsAttrInterface` allows attributes to configure the dialect
conversion infrastructure, including the conversion target, the type converter,
and the conversion pattern set. See the `NVVMTargetAttr` implementation of this
interface for an example of how it can be used to configure dialect conversion.

The `OpWithTransformAttrsOpInterface` allows interacting with transforms
attributes. These attributes allow configuring transformations like dialect
conversion with information present in the IR.

Finally, the `convert-to-llvm` pass was modified to use these interfaces when
available. This allows applying `convert-to-llvm` to GPU modules and letting the
`NVVMTargetAttr` decide which patterns to populate.
---
 .../mlir/Conversion/GPUToNVVM/GPUToNVVM.h     | 26 ++++++
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |  4 +
 .../LLVMCommon/ConversionAttrOptions.h        | 39 +++++++++
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h |  1 +
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  4 +-
 mlir/include/mlir/InitAllExtensions.h         |  2 +
 mlir/include/mlir/Interfaces/CMakeLists.txt   |  8 ++
 .../mlir/Interfaces/TransformsInterfaces.h    | 71 +++++++++++++++
 .../mlir/Interfaces/TransformsInterfaces.td   | 77 ++++++++++++++++
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       | 44 ++++++----
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 87 ++++++++++++++-----
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |  2 +
 .../LLVMCommon/ConversionAttrOptions.cpp      | 27 ++++++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |  1 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 11 +++
 mlir/lib/Interfaces/CMakeLists.txt            |  2 +
 mlir/lib/Interfaces/TransformsInterfaces.cpp  | 53 +++++++++++
 .../GPUToNVVM/gpu-to-nvvm-target-attr.mlir    | 42 +++++++++
 18 files changed, 460 insertions(+), 41 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
 create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h
 create mode 100644 mlir/include/mlir/Interfaces/TransformsInterfaces.h
 create mode 100644 mlir/include/mlir/Interfaces/TransformsInterfaces.td
 create mode 100644 mlir/lib/Conversion/LLVMCommon/ConversionAttrOptions.cpp
 create mode 100644 mlir/lib/Interfaces/TransformsInterfaces.cpp
 create mode 100644 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
new file mode 100644
index 0000000000000..e076132db2a02
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
@@ -0,0 +1,26 @@
+//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace NVVM {
+/// Registers the `ConversionPatternsAttrInterface` interface on the
+/// `NVVM::NVVMTargetAttr`. This interface populates the conversion target,
+/// LLVM type converter, and pattern set for converting GPU operations to NVVM.
+void registerConvertGpuToNVVMAttrInterface(DialectRegistry &registry);
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index e0f4c71051e50..61741a2678c7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 
+/// Configure the LLVM type converter to convert types and address spaces from
+/// the GPU dialect to NVVM.
+void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
+
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h
new file mode 100644
index 0000000000000..8941580efaf3b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h
@@ -0,0 +1,39 @@
+//===- ConversionAttrOptions.h - LLVM conversion options --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares convert to LLVM options for `ConversionPatternAttr`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
+#define MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
+
+#include "mlir/Interfaces/TransformsInterfaces.h"
+
+namespace mlir {
+class LLVMTypeConverter;
+
+/// Class for passing convert to LLVM options to `ConversionPatternAttr`
+/// attributes.
+class LLVMConversionPatternAttrOptions : public ConversionPatternAttrOptions {
+public:
+  LLVMConversionPatternAttrOptions(ConversionTarget &target,
+                                   LLVMTypeConverter &converter);
+
+  static bool classof(ConversionPatternAttrOptions const *opts) {
+    return opts->getTypeID() == TypeID::get<LLVMConversionPatternAttrOptions>();
+  }
+
+  /// Get the LLVM type converter.
+  LLVMTypeConverter &getLLVMTypeConverter();
+};
+} // namespace mlir
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::LLVMConversionPatternAttrOptions)
+
+#endif // MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a84..4865f485c20d7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -28,6 +28,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/TransformsInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c57d291552e60..9d4a4e8ba8553 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -29,6 +29,7 @@ include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TransformsInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // GPU Dialect operations.
@@ -1347,7 +1348,8 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
 
 def GPU_GPUModuleOp : GPU_Op<"module", [
       DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
-      SymbolTable, Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
+      DeclareOpInterfaceMethods<OpWithTransformAttrsOpInterface>, SymbolTable,
+      Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
     ]>, Arguments<(ins SymbolNameAttr:$sym_name,
           OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
           OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler)> {
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..3657e6c47c896 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -18,6 +18,7 @@
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -65,6 +66,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  NVVM::registerConvertGpuToNVVMAttrInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..e941ee862106f 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -36,6 +36,14 @@ mlir_tablegen(DataLayoutTypeInterface.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRDataLayoutInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIRDataLayoutInterfacesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS TransformsInterfaces.td)
+mlir_tablegen(TransformsAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TransformsAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(TransformsOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TransformsOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTransformsInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRTransformsInterfacesIncGen)
+
 add_mlir_doc(DataLayoutInterfaces
   DataLayoutAttrInterface
   Interfaces/
diff --git a/mlir/include/mlir/Interfaces/TransformsInterfaces.h b/mlir/include/mlir/Interfaces/TransformsInterfaces.h
new file mode 100644
index 0000000000000..d4880f4e6fd68
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TransformsInterfaces.h
@@ -0,0 +1,71 @@
+//===- TransformsInterfaces.h - Transforms interfaces -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares interfaces for managing transformations, including
+// populating pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_H
+#define MLIR_INTERFACES_TRANSFORMSINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class ConversionTarget;
+class RewritePatternSet;
+class TypeConverter;
+
+/// This class serves as an opaque interface for passing options to the
+/// `ConversionPatternsAttrInterface` methods. Users of this class must
+/// implement the `classof` method and use the `MLIR_*_EXPLICIT_TYPE_ID`
+/// macros to ensure type safety.
+class ConversionPatternAttrOptions {
+public:
+  ConversionPatternAttrOptions(ConversionTarget &target,
+                               TypeConverter &converter);
+
+  /// Returns the typeID.
+  TypeID getTypeID() const { return typeID; }
+
+  /// Returns a reference to the conversion target to configure.
+  ConversionTarget &getConversionTarget() { return target; }
+
+  /// Returns a reference to the type converter to configure.
+  TypeConverter &getTypeConverter() { return converter; }
+
+protected:
+  /// Derived classes must use this constructor to initialize `typeID` to the
+  /// appropriate value.
+  ConversionPatternAttrOptions(TypeID typeID, ConversionTarget &target,
+                               TypeConverter &converter);
+  // Conversion target.
+  ConversionTarget &target;
+  // Type converter.
+  TypeConverter &converter;
+
+private:
+  TypeID typeID;
+};
+
+/// Helper function for populating dialect conversion patterns. If `op`
+/// implements the `OpWithTransformAttrsOpInterface` interface, then the
+/// conversion pattern attributes provided by the interface will be used to
+/// configure the conversion target, type converter, and the pattern set.
+void populateOpConversionPatterns(Operation *op,
+                                  ConversionPatternAttrOptions &options,
+                                  RewritePatternSet &patterns);
+} // namespace mlir
+
+#include "mlir/Interfaces/TransformsAttrInterfaces.h.inc"
+
+#include "mlir/Interfaces/TransformsOpInterfaces.h.inc"
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::ConversionPatternAttrOptions)
+
+#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/TransformsInterfaces.td b/mlir/include/mlir/Interfaces/TransformsInterfaces.td
new file mode 100644
index 0000000000000..4fb9cea95e5e8
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TransformsInterfaces.td
@@ -0,0 +1,77 @@
+//===- TransformsInterfaces.td - Transforms interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines interfaces for managing transformations, including populating
+// pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
+#define MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns attribute interface
+//===----------------------------------------------------------------------===//
+
+def ConversionPatternsAttrInterface :
+    AttrInterface<"ConversionPatternsAttrInterface"> {
+  let description = [{
+    This interface allows using attributes to configure the dialect conversion
+    infrastructure, including:
+     - The conversion target.
+     - The type converter.
+     - The pattern set.
+    
+    The conversion target and type converter are passed through the
+    `ConversionPatternAttrOptions` class. Passing them by reference through
+    this class makes it possible to subclass the base options class and
+    provide specializations such as `LLVMConversionPatternAttrOptions` for
+    converting to LLVM.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the dialect conversion target, type converter and pattern set.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"populateConversionPatterns",
+      /*args=*/(ins "::mlir::ConversionPatternAttrOptions&":$options,
+                    "::mlir::RewritePatternSet&":$patternSet)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operation with transform attributes interface
+//===----------------------------------------------------------------------===//
+
+def OpWithTransformAttrsOpInterface :
+    OpInterface<"OpWithTransformAttrsOpInterface"> {
+  let description = [{
+    Interface for interacting with transforms attributes. These attributes
+    allow configuring transformations like dialect conversion with information
+    present in the IR.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the provided vector with a list of conversion pattern
+        attributes to apply.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getConversionPatternAttrs",
+      /*args=*/(ins
+        "::llvm::SmallVectorImpl<::mlir::ConversionPatternsAttrInterface>&":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 6135117348a5b..66715912d9647 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -8,10 +8,12 @@
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TransformsInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -61,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
 /// the injection of conversion patterns.
 class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
-  std::shared_ptr<const FrozenRewritePatternSet> patterns;
-  std::shared_ptr<const ConversionTarget> target;
-  std::shared_ptr<const LLVMTypeConverter> typeConverter;
+  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+      interfaces;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -73,11 +74,8 @@ class ConvertToLLVMPass
   }
 
   LogicalResult initialize(MLIRContext *context) final {
-    RewritePatternSet tempPatterns(context);
-    auto target = std::make_shared<ConversionTarget>(*context);
-    target->addLegalDialect<LLVM::LLVMDialect>();
-    auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
-
+    auto interfaces =
+        std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
     if (!filterDialects.empty()) {
       // Test mode: Populate only patterns from the specified dialects. Produce
       // an error if the dialect is not loaded or does not implement the
@@ -92,8 +90,7 @@ class ConvertToLLVMPass
           return emitError(UnknownLoc::get(context))
                  << "dialect does not implement ConvertToLLVMPatternInterface: "
                  << dialectName << "\n";
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     } else {
       // Normal mode: Populate all patterns from all dialects that implement the
@@ -104,20 +101,33 @@ class ConvertToLLVMPass
         auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
         if (!iface)
           continue;
-        iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
-                                                       tempPatterns);
+        interfaces->push_back(iface);
       }
     }
 
-    this->patterns =
-        std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
-    this->target = target;
-    this->typeConverter = typeConverter;
+    this->interfaces = interfaces;
     return success();
   }
 
   void runOnOperation() final {
-    if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    LLVMTypeConverter typeConverter(context);
+
+    // Configure the conversion with dialect level interfaces.
+    for (ConvertToLLVMPatternInterface *iface : *interfaces)
+      iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                     patterns);
+
+    // Configure the conversion attribute interfaces.
+    LLVMConversionPatternAttrOptions opts(target, typeConverter);
+    populateOpConversionPatterns(getOperation(), opts, patterns);
+
+    // Apply the conversion.
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..2a47555d67f28 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -261,29 +263,7 @@ struct LowerGpuOpsToNVVMOpsPass
     }
 
     LLVMTypeConverter converter(m.getContext(), options);
-    // NVVM uses alloca in the default address space to represent private
-    // memory allocations, so drop private annotations. NVVM uses address
-    // space 3 for shared memory. NVVM uses the default address space to
-    // represent global memory.
-    populateGpuMemorySpaceAttributeConversions(
-        converter, [](gpu::AddressSpace space) -> unsigned {
-          switch (space) {
-          case gpu::AddressSpace::Global:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-          case gpu::AddressSpace::Workgroup:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kSharedMemorySpace);
-          case gpu::AddressSpace::Private:
-            return 0;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    // Lowering for MMAMatrixType.
-    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToLLVMType(type);
-    });
+    configureGpuToNVVMTypeConverter(converter);
     RewritePatternSet llvmPatterns(m.getContext());
 
     arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -318,6 +298,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
 
+void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
+  // NVVM uses alloca in the default address space to represent private
+  // memory allocations, so drop private annotations. NVVM uses address
+  // space 3 for shared memory. NVVM uses the default address space to
+  // represent global memory.
+  populateGpuMemorySpaceAttributeConversions(
+      converter, [](gpu::AddressSpace space) -> unsigned {
+        switch (space) {
+        case gpu::AddressSpace::Global:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+        case gpu::AddressSpace::Workgroup:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+        case gpu::AddressSpace::Private:
+          return 0;
+        }
+        llvm_unreachable("unknown address space enum value");
+        return 0;
+      });
+  // Lowering for MMAMatrixType.
+  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+    return convertMMAToLLVMType(type);
+  });
+}
+
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
@@ -409,3 +415,38 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                    "__nv_tanh");
   populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
 }
+
+//===----------------------------------------------------------------------===//
+// NVVMTargetAttr conversion patterns attr interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct NVVMTargetPatternsAttrInterface
+    : public ConversionPatternsAttrInterface::ExternalModel<
+          NVVMTargetPatternsAttrInterface, NVVM::NVVMTargetAttr> {
+  /// Configure GPU to NVVM.
+  void populateConversionPatterns(Attribute attr,
+                                  ConversionPatternAttrOptions &options,
+                                  RewritePatternSet &patterns) const;
+};
+} // namespace
+
+void NVVMTargetPatternsAttrInterface::populateConversionPatterns(
+    Attribute attr, ConversionPatternAttrOptions &options,
+    RewritePatternSet &patterns) const {
+  auto *llvmOptions = dyn_cast<LLVMConversionPatternAttrOptions>(&options);
+  // Bail if the options are invalid.
+  if (!llvmOptions)
+    return;
+  configureGpuToNVVMConversionLegality(options.getConversionTarget());
+  configureGpuToNVVMTypeConverter(llvmOptions->getLLVMTypeConverter());
+  populateGpuToNVVMConversionPatterns(llvmOptions->getLLVMTypeConverter(),
+                                      patterns);
+}
+
+void mlir::NVVM::registerConvertGpuToNVVMAttrInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
+    NVVMTargetAttr::attachInterface<NVVMTargetPatternsAttrInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabc..0133c2bc4c257 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRLLVMCommonConversion
+  ConversionAttrOptions.cpp
   ConversionTarget.cpp
   LoweringOptions.cpp
   MemRefBuilder.cpp
@@ -16,4 +17,5 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRSupport
   MLIRTransforms
+  MLIRTransformsInterfaces
   )
diff --git a/mlir/lib/Conversion/LLVMCommon/ConversionAttrOptions.cpp b/mlir/lib/Conversion/LLVMCommon/ConversionAttrOptions.cpp
new file mode 100644
index 0000000000000..2173111205971
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/ConversionAttrOptions.cpp
@@ -0,0 +1,27 @@
+//===- ConversionAttrOptions.cpp - LLVM conversion options ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines convert to LLVM options for `ConversionPatternAttr`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+
+using namespace mlir;
+
+LLVMConversionPatternAttrOptions::LLVMConversionPatternAttrOptions(
+    ConversionTarget &target, LLVMTypeConverter &converter)
+    : ConversionPatternAttrOptions(
+          TypeID::get<LLVMConversionPatternAttrOptions>(), target, converter) {}
+
+LLVMTypeConverter &LLVMConversionPatternAttrOptions::getLLVMTypeConverter() {
+  return static_cast<LLVMTypeConverter &>(converter);
+}
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::LLVMConversionPatternAttrOptions)
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 8e4cef5af7e37..75f2a8193ae2d 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRGPUDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRSupport
+  MLIRTransformsInterfaces
   )
 
 add_mlir_dialect_library(MLIRGPUTransforms
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb..f4f2e96ffe889 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1832,6 +1832,17 @@ void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
   targetsAttr = ArrayAttr::get(getContext(), targetsVector);
 }
 
+void GPUModuleOp::getConversionPatternAttrs(
+    SmallVectorImpl<ConversionPatternsAttrInterface> &attrs) {
+  ArrayAttr targetsAttr = getTargetsAttr();
+  // Bail out if there are no target attributes or more than one.
+  if (!targetsAttr || targetsAttr.size() != 1)
+    return;
+  if (auto patternAttr =
+          dyn_cast<ConversionPatternsAttrInterface>(targetsAttr[0]))
+    attrs.push_back(patternAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // GPUBinaryOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index d3b7bf65ad3e7..be53a368700cf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -18,6 +18,7 @@ set(LLVM_OPTIONAL_SOURCES
   SideEffectInterfaces.cpp
   SubsetOpInterface.cpp
   TilingInterface.cpp
+  TransformsInterfaces.cpp
   ValueBoundsOpInterface.cpp
   VectorInterfaces.cpp
   ViewLikeInterface.cpp
@@ -83,6 +84,7 @@ add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
+add_mlir_interface_library(TransformsInterfaces)
 
 add_mlir_library(MLIRSubsetOpInterface
   SubsetOpInterface.cpp
diff --git a/mlir/lib/Interfaces/TransformsInterfaces.cpp b/mlir/lib/Interfaces/TransformsInterfaces.cpp
new file mode 100644
index 0000000000000..78ff8f124ae48
--- /dev/null
+++ b/mlir/lib/Interfaces/TransformsInterfaces.cpp
@@ -0,0 +1,53 @@
+//===- TransformsInterfaces.cpp - Transforms interfaces ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for managing transformations, including
+// populating pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/TransformsInterfaces.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternAttrOptions
+//===----------------------------------------------------------------------===//
+
+ConversionPatternAttrOptions::ConversionPatternAttrOptions(
+    ConversionTarget &target, TypeConverter &converter)
+    : ConversionPatternAttrOptions(TypeID::get<ConversionPatternAttrOptions>(),
+                                   target, converter) {}
+
+ConversionPatternAttrOptions::ConversionPatternAttrOptions(
+    TypeID typeID, ConversionTarget &target, TypeConverter &converter)
+    : target(target), converter(converter), typeID(typeID) {}
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::ConversionPatternAttrOptions)
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
+void mlir::populateOpConversionPatterns(Operation *op,
+                                        ConversionPatternAttrOptions &options,
+                                        RewritePatternSet &patterns) {
+  auto iface = dyn_cast<OpWithTransformAttrsOpInterface>(op);
+  if (!iface)
+    return;
+  SmallVector<ConversionPatternsAttrInterface, 12> attrs;
+  iface.getConversionPatternAttrs(attrs);
+  for (ConversionPatternsAttrInterface attr : attrs)
+    attr.populateConversionPatterns(options, patterns);
+}
+
+#include "mlir/Interfaces/TransformsAttrInterfaces.cpp.inc"
+
+#include "mlir/Interfaces/TransformsOpInterfaces.cpp.inc"
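`ConversionPatternAttrOptions` stores a `TypeID` alongside the conversion target and type converter, and `populateOpConversionPatterns` hands the same options object to every attribute the op reports. The patch does not say what the stored `TypeID` is for. One plausible use, assumed here, is letting an attribute recognize an extended options subclass through LLVM-style RTTI. The sketch below models that flow in plain C++. `typeIdOf`, `PatternOptions`, `GpuPatternOptions`, `enableExtraPatterns`, `NvvmLikeProvider`, `Op` and `populateOpPatterns` are all hypothetical names, and patterns are plain strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::TypeID: each instantiation of typeIdOf<T>
// owns a distinct static object, so its address identifies T.
using TypeId = const void *;
template <typename T> TypeId typeIdOf() {
  static char tag;
  return &tag;
}

// Hypothetical stand-ins for ConversionTarget and TypeConverter.
struct Target {
  std::vector<std::string> legalDialects;
};
struct Converter {
  unsigned indexBitwidth = 64;
};

// Models ConversionPatternAttrOptions: references to the target and the
// converter, plus the TypeId of the most-derived options class.
class PatternOptions {
public:
  PatternOptions(Target &target, Converter &converter)
      : PatternOptions(typeIdOf<PatternOptions>(), target, converter) {}
  TypeId getTypeId() const { return typeId; }

  Target &target;
  Converter &converter;

protected:
  // Derived options classes pass their own TypeId through this constructor.
  PatternOptions(TypeId id, Target &target, Converter &converter)
      : target(target), converter(converter), typeId(id) {}

private:
  TypeId typeId;
};

// Hypothetical extended options that only some attributes understand.
class GpuPatternOptions : public PatternOptions {
public:
  GpuPatternOptions(Target &target, Converter &converter, bool enableExtra)
      : PatternOptions(typeIdOf<GpuPatternOptions>(), target, converter),
        enableExtraPatterns(enableExtra) {}
  static bool classof(const PatternOptions *opts) {
    return opts->getTypeId() == typeIdOf<GpuPatternOptions>();
  }
  bool enableExtraPatterns;
};

// Checked downcast driven by the stored TypeId, in the spirit of dyn_cast.
template <typename To> To *dynCastOptions(PatternOptions &opts) {
  return To::classof(&opts) ? static_cast<To *>(&opts) : nullptr;
}

using PatternSet = std::vector<std::string>;

// Models ConversionPatternsAttrInterface.
struct PatternProvider {
  virtual ~PatternProvider() = default;
  virtual void populate(PatternOptions &opts, PatternSet &patterns) const = 0;
};

// Hypothetical NVVM-like attribute: always adds its base patterns and only
// adds extra ones when handed the extended options type.
struct NvvmLikeProvider : PatternProvider {
  void populate(PatternOptions &opts, PatternSet &patterns) const override {
    opts.target.legalDialects.push_back("nvvm");
    patterns.push_back("gpu-to-nvvm");
    if (GpuPatternOptions *gpuOpts = dynCastOptions<GpuPatternOptions>(opts))
      if (gpuOpts->enableExtraPatterns)
        patterns.push_back("extra-gpu-patterns");
  }
};

// Models an operation that may or may not implement
// OpWithTransformAttrsOpInterface.
struct Op {
  bool implementsInterface;
  std::vector<const PatternProvider *> patternAttrs;
};

// Models populateOpConversionPatterns: ops without the interface contribute
// nothing; otherwise every attribute populates with the same options.
void populateOpPatterns(const Op &op, PatternOptions &opts,
                        PatternSet &patterns) {
  if (!op.implementsInterface)
    return;
  for (const PatternProvider *attr : op.patternAttrs)
    attr->populate(opts, patterns);
}
```

In this sketch, plain `PatternOptions` yields only `gpu-to-nvvm`. `GpuPatternOptions` with `enableExtraPatterns` set also yields `extra-gpu-patterns`, and an `Op` without the interface yields nothing. A `TypeID`-based check also works in LLVM's default builds, which disable C++ RTTI.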
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
new file mode 100644
index 0000000000000..6e55a56d8c9aa
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+
+// CHECK-LABEL: gpu.module @nvvm_module
+gpu.module @nvvm_module [#nvvm.target] {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %tIdX = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+// CHECK-LABEL: llvm.func @kernel_1
+// CHECK: (%{{.*}}: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64)
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.module @nvvm_module_2
+gpu.module @nvvm_module_2 {
+  // CHECK-LABEL: llvm.func @kernel_0()
+  func.func @kernel_0() -> index {
+    // CHECK: = gpu.thread_id x
+    %tIdX = gpu.thread_id x
+    // CHECK: = gpu.lane_id
+    %laneId = gpu.lane_id
+    %sum = index.add %tIdX, %laneId
+    func.return %sum : index
+  }
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: (%{{.*}}: memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>}
+  gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+    gpu.return
+  }
+}


