[Mlir-commits] [mlir] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces (PR #99566)
Fabian Mora
llvmlistbot at llvm.org
Sat Nov 23 15:13:50 PST 2024
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/99566
>From f4c3955aaa5228113d3ad0cdb859a890edbadf48 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 25 Jul 2024 20:53:36 +0000
Subject: [PATCH 1/4] [mlir][LLVM] Add the `ConvertToLLVMAttrInterface` and
`ConvertToLLVMOpInterface` interfaces
This patch adds the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available.
The `ConvertToLLVMAttrInterface` interface allows attributes to configure conversion to LLVM, including the conversion target, the LLVM type converter, and the conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure conversion to LLVM.
The `ConvertToLLVMOpInterface` interface collects all convert to LLVM attributes stored in an operation.
Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
---
mlir/include/mlir/Conversion/CMakeLists.txt | 2 +
.../Conversion/ConvertToLLVM/CMakeLists.txt | 7 ++
.../ConvertToLLVM/ToLLVMInterface.h | 13 +++
.../ConvertToLLVM/ToLLVMInterface.td | 76 +++++++++++++++++
.../mlir/Conversion/GPUCommon/GPUToLLVM.h | 25 ++++++
.../mlir/Conversion/GPUToNVVM/GPUToNVVM.h | 27 ++++++
.../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 4 +
mlir/include/mlir/InitAllExtensions.h | 4 +
.../Conversion/ConvertToLLVM/CMakeLists.txt | 1 +
.../ConvertToLLVM/ConvertToLLVMPass.cpp | 42 ++++++----
.../ConvertToLLVM/ToLLVMInterface.cpp | 17 ++++
.../GPUCommon/GPUToLLVMConversion.cpp | 32 +++++++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 83 ++++++++++++++-----
.../GPUToNVVM/gpu-to-nvvm-target-attr.mlir | 42 ++++++++++
14 files changed, 335 insertions(+), 40 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
create mode 100644 mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
create mode 100644 mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
create mode 100644 mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
create mode 100644 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt
index d212bf3e395e71..9f76ab659215ea 100644
--- a/mlir/include/mlir/Conversion/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/CMakeLists.txt
@@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
add_public_tablegen_target(MLIRConversionPassIncGen)
add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
+
+add_subdirectory(ConvertToLLVM)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..54d7a03fc22dff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -0,0 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td)
+mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 00aeed9bf29dc2..1d14ff30eb5516 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -11,6 +11,7 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
namespace mlir {
class ConversionTarget;
@@ -50,6 +51,18 @@ void populateConversionTargetFromOperation(Operation *op,
LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Helper function for populating LLVM conversion patterns. If `op` implements
+/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern
+/// attributes provided by the interface will be used to configure the
+/// conversion target, type converter, and the pattern set.
+void populateOpConvertToLLVMConversionPatterns(Operation *op,
+ ConversionTarget &target,
+ LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
} // namespace mlir
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc"
+
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
new file mode 100644
index 00000000000000..1331a9802c570f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td
@@ -0,0 +1,76 @@
+
+//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the attribute and op interfaces used to configure the conversion to
+// LLVM, including populating conversion patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMAttrInterface :
+ AttrInterface<"ConvertToLLVMAttrInterface"> {
+ let description = [{
+ The `ConvertToLLVMAttrInterface` attribute interface allows using
+ attributes to configure the convert to LLVM infrastructure. This includes:
+ - The conversion target.
+ - The LLVM type converter.
+ - The pattern set.
+
+ This interface permits fine-grained configuration of the `convert-to-llvm`
+ process. For example, attributes with target information like
+ `#nvvm.target` or `#rocdl.target` can leverage this interface for populating
+ patterns specific to a particular target.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the dialect conversion target, type converter and pattern set.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"populateConvertToLLVMConversionPatterns",
+ /*args=*/(ins "::mlir::ConversionTarget&":$target,
+ "::mlir::LLVMTypeConverter&":$typeConverter,
+ "::mlir::RewritePatternSet&":$patternSet)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Op interface
+//===----------------------------------------------------------------------===//
+
+def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> {
+ let description = [{
+ Interface for collecting all convert to LLVM attributes stored in an
+ operation. See `ConvertToLLVMAttrInterface` for more information on these
+ attributes.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the provided vector with a list of convert to LLVM attributes
+ to apply.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getConvertToLLVMConversionAttrs",
+ /*args=*/(ins
+ "::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs)
+ >
+ ];
+}
+
+#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
new file mode 100644
index 00000000000000..ad8c39fe676618
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h
@@ -0,0 +1,25 @@
+//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace gpu {
+/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOp`
+/// operation.
+void registerConvertGpuToLLVMInterface(DialectRegistry &registry);
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
new file mode 100644
index 00000000000000..6311630a23c8f6
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
@@ -0,0 +1,27 @@
+//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace NVVM {
+/// Registers the `ConvertToLLVMAttrInterface` interface on the
+/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion
+/// target, LLVM type converter, and pattern set for converting GPU operations
+/// to NVVM.
+void registerConvertGpuToNVVMInterface(DialectRegistry &registry);
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 645e86a4309621..fc7c967f1b62cf 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
/// Configure target to convert from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
+/// Configure the LLVM type converter to convert types and address spaces from
+/// the GPU dialect to NVVM.
+void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
+
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 1f2ef26b450701..14a6a2787b3a5d 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -18,6 +18,8 @@
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -72,6 +74,8 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
+ gpu::registerConvertGpuToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
index de3d850d520c0f..a9b49391e36a70 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface
ToLLVMInterface.cpp
DEPENDS
+ MLIRConvertToLLVMInterfaceIncGen
LINK_LIBS PUBLIC
MLIRIR
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index b2407a258c2719..0f3e852977ada2 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -63,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
/// the injection of conversion patterns.
class ConvertToLLVMPass
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
- std::shared_ptr<const FrozenRewritePatternSet> patterns;
- std::shared_ptr<const ConversionTarget> target;
- std::shared_ptr<const LLVMTypeConverter> typeConverter;
+ std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+ interfaces;
public:
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -75,11 +74,8 @@ class ConvertToLLVMPass
}
LogicalResult initialize(MLIRContext *context) final {
- RewritePatternSet tempPatterns(context);
- auto target = std::make_shared<ConversionTarget>(*context);
- target->addLegalDialect<LLVM::LLVMDialect>();
- auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
-
+ auto interfaces =
+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
if (!filterDialects.empty()) {
// Test mode: Populate only patterns from the specified dialects. Produce
// an error if the dialect is not loaded or does not implement the
@@ -94,8 +90,7 @@ class ConvertToLLVMPass
return emitError(UnknownLoc::get(context))
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialectName << "\n";
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
+ interfaces->push_back(iface);
}
} else {
// Normal mode: Populate all patterns from all dialects that implement the
@@ -106,20 +101,33 @@ class ConvertToLLVMPass
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
+ interfaces->push_back(iface);
}
}
- this->patterns =
- std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
- this->target = target;
- this->typeConverter = typeConverter;
+ this->interfaces = interfaces;
return success();
}
void runOnOperation() final {
- if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ LLVMTypeConverter typeConverter(context);
+
+ // Configure the conversion with dialect level interfaces.
+ for (ConvertToLLVMPatternInterface *iface : *interfaces)
+ iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+ patterns);
+
+ // Configure the conversion attribute interfaces.
+ populateOpConvertToLLVMConversionPatterns(getOperation(), target,
+ typeConverter, patterns);
+
+ // Apply the conversion.
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 3a4e83b2a8838f..5cc71178b0f28d 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -30,3 +30,20 @@ void mlir::populateConversionTargetFromOperation(
patterns);
});
}
+
+void mlir::populateOpConvertToLLVMConversionPatterns(
+ Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+ if (!iface)
+ return;
+ SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
+ iface.getConvertToLLVMConversionAttrs(attrs);
+ for (ConvertToLLVMAttrInterface attr : attrs)
+ attr.populateConvertToLLVMConversionPatterns(target, typeConverter,
+ patterns);
+}
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.cpp.inc"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.cpp.inc"
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c58737..1497d662dcdbdd 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -22,6 +22,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -1762,3 +1763,34 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
}
+
+//===----------------------------------------------------------------------===//
+// GPUModuleOp convert to LLVM op interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct GPUModuleOpConvertToLLVMInterface
+ : public ConvertToLLVMOpInterface::ExternalModel<
+ GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
+ /// Get the conversion patterns from the target attribute.
+ void getConvertToLLVMConversionAttrs(
+ Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
+};
+} // namespace
+
+void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
+ Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
+ auto module = cast<gpu::GPUModuleOp>(op);
+ ArrayAttr targetsAttr = module.getTargetsAttr();
+ // Do nothing unless the module has exactly one target attribute.
+ if (!targetsAttr || targetsAttr.size() != 1)
+ return;
+ if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
+ attrs.push_back(patternAttr);
+}
+
+void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
+ gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 04e85c2b337dec..b343cf71e3a2e7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,8 +15,10 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -274,29 +276,7 @@ struct LowerGpuOpsToNVVMOpsPass
}
LLVMTypeConverter converter(m.getContext(), options);
- // NVVM uses alloca in the default address space to represent private
- // memory allocations, so drop private annotations. NVVM uses address
- // space 3 for shared memory. NVVM uses the default address space to
- // represent global memory.
- populateGpuMemorySpaceAttributeConversions(
- converter, [](gpu::AddressSpace space) -> unsigned {
- switch (space) {
- case gpu::AddressSpace::Global:
- return static_cast<unsigned>(
- NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- case gpu::AddressSpace::Workgroup:
- return static_cast<unsigned>(
- NVVM::NVVMMemorySpace::kSharedMemorySpace);
- case gpu::AddressSpace::Private:
- return 0;
- }
- llvm_unreachable("unknown address space enum value");
- return 0;
- });
- // Lowering for MMAMatrixType.
- converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
- return convertMMAToLLVMType(type);
- });
+ configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -332,6 +312,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
+void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
+ // NVVM uses alloca in the default address space to represent private
+ // memory allocations, so drop private annotations. NVVM uses address
+ // space 3 for shared memory. NVVM uses the default address space to
+ // represent global memory.
+ populateGpuMemorySpaceAttributeConversions(
+ converter, [](gpu::AddressSpace space) -> unsigned {
+ switch (space) {
+ case gpu::AddressSpace::Global:
+ return static_cast<unsigned>(
+ NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+ case gpu::AddressSpace::Workgroup:
+ return static_cast<unsigned>(
+ NVVM::NVVMMemorySpace::kSharedMemorySpace);
+ case gpu::AddressSpace::Private:
+ return 0;
+ }
+ llvm_unreachable("unknown address space enum value");
+ return 0;
+ });
+ // Lowering for MMAMatrixType.
+ converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+ return convertMMAToLLVMType(type);
+ });
+}
+
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
@@ -467,3 +473,34 @@ void mlir::populateGpuToNVVMConversionPatterns(
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
}
+
+//===----------------------------------------------------------------------===//
+// NVVMTargetAttr convert to LLVM attr interface
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct NVVMTargetConvertToLLVMAttrInterface
+ : public ConvertToLLVMAttrInterface::ExternalModel<
+ NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
+ /// Configure GPU to NVVM.
+ void populateConvertToLLVMConversionPatterns(
+ Attribute attr, ConversionTarget &target,
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
+};
+} // namespace
+
+void NVVMTargetConvertToLLVMAttrInterface::
+ populateConvertToLLVMConversionPatterns(Attribute attr,
+ ConversionTarget &target,
+ LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const {
+ configureGpuToNVVMConversionLegality(target);
+ configureGpuToNVVMTypeConverter(typeConverter);
+ populateGpuToNVVMConversionPatterns(typeConverter, patterns);
+}
+
+void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
+ NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
+ });
+}
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
new file mode 100644
index 00000000000000..6e55a56d8c9aa5
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+
+// CHECK-LABEL: gpu.module @nvvm_module
+gpu.module @nvvm_module [#nvvm.target] {
+ // CHECK-LABEL: llvm.func @kernel_0()
+ func.func @kernel_0() -> index {
+ // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: = llvm.sext %{{.*}} : i32 to i64
+ %tIdX = gpu.thread_id x
+ // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+ // CHECK: = llvm.sext %{{.*}} : i32 to i64
+ %laneId = gpu.lane_id
+ %sum = index.add %tIdX, %laneId
+ func.return %sum : index
+ }
+
+// CHECK-LABEL: llvm.func @kernel_1
+// CHECK: (%{{.*}}: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64)
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
+ gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+ gpu.return
+ }
+}
+
+// CHECK-LABEL: gpu.module @nvvm_module_2
+gpu.module @nvvm_module_2 {
+ // CHECK-LABEL: llvm.func @kernel_0()
+ func.func @kernel_0() -> index {
+ // CHECK: = gpu.thread_id x
+ %tIdX = gpu.thread_id x
+ // CHECK: = gpu.lane_id
+ %laneId = gpu.lane_id
+ %sum = index.add %tIdX, %laneId
+ func.return %sum : index
+ }
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: (%{{.*}}: memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>}
+ gpu.func @kernel_1(%arg0 : memref<f32, #gpu.address_space<global>>) kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+ gpu.return
+ }
+}
>From 97e79915d8b1d96b0c3720a1c8d5265123214bb8 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 29 Jul 2024 18:56:18 +0000
Subject: [PATCH 2/4] add option to control whether to use conversion
attributes
---
mlir/include/mlir/Conversion/Passes.td | 6 +++
.../ConvertToLLVM/ConvertToLLVMPass.cpp | 51 +++++++++++++++++--
.../ConvertToLLVM/ToLLVMInterface.cpp | 2 +
.../GPUToNVVM/gpu-to-nvvm-target-attr.mlir | 2 +-
4 files changed, 55 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f1..57d187bd6dcb32 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -22,12 +22,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
This is a generic pass to convert to LLVM, it uses the
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
the injection of conversion patterns.
+
+ If `use-conversion-attrs` is set to `true`, the pass will look for
+ `ConvertToLLVMAttrInterface` attributes and use them to further configure
+ the conversion process. Enabling this option incurs extra overhead.
}];
let constructor = "mlir::createConvertToLLVMPass()";
let options = [
ListOption<"filterDialects", "filter-dialects", "std::string",
"Test conversion patterns of only the specified dialects">,
+ Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
+ "Use op conversion attributes to configure the conversion">,
];
}
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 0f3e852977ada2..a06de03aa5ec42 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -65,6 +65,9 @@ class ConvertToLLVMPass
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
interfaces;
+ std::shared_ptr<const FrozenRewritePatternSet> patterns;
+ std::shared_ptr<const ConversionTarget> target;
+ std::shared_ptr<const LLVMTypeConverter> typeConverter;
public:
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -74,8 +77,22 @@ class ConvertToLLVMPass
}
LogicalResult initialize(MLIRContext *context) final {
- auto interfaces =
- std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+ std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
+ std::shared_ptr<ConversionTarget> target;
+ std::shared_ptr<LLVMTypeConverter> typeConverter;
+ RewritePatternSet tempPatterns(context);
+
+ // Only collect the interfaces if `useConversionAttrs=true` as everything
+ // else must be initialized in `runOnOperation`.
+ if (useConversionAttrs) {
+ interfaces =
+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+ } else {
+ target = std::make_shared<ConversionTarget>(*context);
+ target->addLegalDialect<LLVM::LLVMDialect>();
+ typeConverter = std::make_shared<LLVMTypeConverter>(context);
+ }
+
if (!filterDialects.empty()) {
// Test mode: Populate only patterns from the specified dialects. Produce
// an error if the dialect is not loaded or does not implement the
@@ -90,7 +107,12 @@ class ConvertToLLVMPass
return emitError(UnknownLoc::get(context))
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialectName << "\n";
- interfaces->push_back(iface);
+ if (useConversionAttrs) {
+ interfaces->push_back(iface);
+ continue;
+ }
+ iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+ tempPatterns);
}
} else {
// Normal mode: Populate all patterns from all dialects that implement the
@@ -101,15 +123,34 @@ class ConvertToLLVMPass
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
- interfaces->push_back(iface);
+ if (useConversionAttrs) {
+ interfaces->push_back(iface);
+ continue;
+ }
+ iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+ tempPatterns);
}
}
- this->interfaces = interfaces;
+ if (useConversionAttrs) {
+ this->interfaces = interfaces;
+ } else {
+ this->patterns =
+ std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
+ this->target = target;
+ this->typeConverter = typeConverter;
+ }
return success();
}
void runOnOperation() final {
+ // Fast path:
+ if (!useConversionAttrs) {
+ if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+ signalPassFailure();
+ return;
+ }
+ // Slow path with conversion attributes.
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 5cc71178b0f28d..252245dfbf5417 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -35,6 +35,8 @@ void mlir::populateOpConvertToLLVMConversionPatterns(
Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) {
auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
+ if (!iface)
+ iface = op->getParentOfType<ConvertToLLVMOpInterface>();
if (!iface)
return;
SmallVector<ConvertToLLVMAttrInterface, 12> attrs;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index 6e55a56d8c9aa5..132eb473cb9189 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
// CHECK-LABEL: gpu.module @nvvm_module
gpu.module @nvvm_module [#nvvm.target] {
>From f550639aa18f7d3cf0ea0fe30733925da394433a Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 23 Nov 2024 22:53:26 +0000
Subject: [PATCH 3/4] address reviewer comments
---
.../ConvertToLLVM/ToLLVMInterface.h | 37 +++
mlir/include/mlir/Conversion/Passes.td | 8 +-
.../ConvertToLLVM/ConvertToLLVMPass.cpp | 243 +++++++++++-------
.../GPUToNVVM/gpu-to-nvvm-target-attr.mlir | 4 +-
4 files changed, 198 insertions(+), 94 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 1d14ff30eb5516..af8150d1d329e2 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -19,6 +19,7 @@ class LLVMTypeConverter;
class MLIRContext;
class Operation;
class RewritePatternSet;
+class AnalysisManager;
/// Base class for dialect interfaces providing translation to LLVM IR.
/// Dialects that can be translated should provide an implementation of this
@@ -59,6 +60,42 @@ void populateOpConvertToLLVMConversionPatterns(Operation *op,
ConversionTarget &target,
LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Base class for creating the internal implementation of `convert-to-llvm`
+/// passes.
+class ConvertToLLVMPassInterface {
+public:
+ ConvertToLLVMPassInterface(MLIRContext *context,
+ ArrayRef<std::string> filterDialects);
+ virtual ~ConvertToLLVMPassInterface() = default;
+
+ /// Get the dependent dialects used by `convert-to-llvm`.
+ static void getDependentDialects(DialectRegistry &registry);
+
+ /// Initialize the internal state of the `convert-to-llvm` pass
+ /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
+ /// This method returns whether the initialization process failed.
+ virtual LogicalResult initialize() = 0;
+
+ /// Transform `op` to LLVM with the conversions available in the pass. The
+ /// analysis manager can be used to query analyses like `DataLayoutAnalysis`
+ /// to further configure the conversion process. This method is invoked by
+ /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
+ /// transformation process failed.
+ virtual LogicalResult transform(Operation *op,
+ AnalysisManager manager) const = 0;
+
+protected:
+ /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
+ /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
+ /// then `visitor` is invoked only with the dialects in the `filterDialects`
+ /// list.
+ LogicalResult visitInterfaces(
+ llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
+ MLIRContext *context;
+ /// List of dialect names to use as filters.
+ ArrayRef<std::string> filterDialects;
+};
} // namespace mlir
#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 57d187bd6dcb32..e394bae64e0918 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -23,16 +23,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
the injection of conversion patterns.
- If `use-conversion-attrs` is set to `true`, the pass will look for
+ If `dynamic` is set to `true`, the pass will look for
`ConvertToLLVMAttrInterface` attributes and use them to further configure
- the conversion process. Enabling this option incurs in extra overhead.
+ the conversion process. This option also uses the `DataLayoutAnalysis`
+ analysis to configure the type converter. Enabling this option incurs
+ extra overhead.
}];
let constructor = "mlir::createConvertToLLVMPass()";
let options = [
ListOption<"filterDialects", "filter-dialects", "std::string",
"Test conversion patterns of only the specified dialects">,
- Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
+ Option<"useDynamic", "dynamic", "bool", "false",
"Use op conversion attributes to configure the conversion">,
];
}
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index a06de03aa5ec42..f9b88d80695aca 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -27,7 +28,6 @@ namespace mlir {
using namespace mlir;
namespace {
-
/// This DialectExtension can be attached to the context, which will invoke the
/// `apply()` method for every loaded dialect. If a dialect implements the
/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
@@ -58,104 +58,82 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
}
};
-/// This is a generic pass to convert to LLVM, it uses the
-/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
-/// the injection of conversion patterns.
-class ConvertToLLVMPass
- : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
- std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
- interfaces;
+//===----------------------------------------------------------------------===//
+// StaticConvertToLLVM
+//===----------------------------------------------------------------------===//
+
+/// Static implementation of the `convert-to-llvm` pass. This version only looks
+/// at dialect interfaces to configure the conversion process.
+struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
+ /// Pattern set with conversions to LLVM.
std::shared_ptr<const FrozenRewritePatternSet> patterns;
+ /// The conversion target.
std::shared_ptr<const ConversionTarget> target;
+ /// The LLVM type converter.
std::shared_ptr<const LLVMTypeConverter> typeConverter;
+ using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
-public:
- using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
- void getDependentDialects(DialectRegistry &registry) const final {
- registry.insert<LLVM::LLVMDialect>();
- registry.addExtensions<LoadDependentDialectExtension>();
+ /// Configure the conversion to LLVM at pass initialization.
+ LogicalResult initialize() final {
+ auto target = std::make_shared<ConversionTarget>(*context);
+ auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
+ RewritePatternSet tempPatterns(context);
+ target->addLegalDialect<LLVM::LLVMDialect>();
+ // Populate the patterns with the dialect interface.
+ if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
+ iface->populateConvertToLLVMConversionPatterns(
+ *target, *typeConverter, tempPatterns);
+ })))
+ return failure();
+ this->patterns =
+ std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
+ this->target = target;
+ this->typeConverter = typeConverter;
+ return success();
}
- LogicalResult initialize(MLIRContext *context) final {
- std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
- std::shared_ptr<ConversionTarget> target;
- std::shared_ptr<LLVMTypeConverter> typeConverter;
- RewritePatternSet tempPatterns(context);
+ /// Apply the conversion driver.
+ LogicalResult transform(Operation *op, AnalysisManager manager) const final {
+ if (failed(applyPartialConversion(op, *target, *patterns)))
+ return failure();
+ return success();
+ }
+};
- // Only collect the interfaces if `useConversionAttrs=true` as everything
- // else must be initialized in `runOnOperation`.
- if (useConversionAttrs) {
- interfaces =
- std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
- } else {
- target = std::make_shared<ConversionTarget>(*context);
- target->addLegalDialect<LLVM::LLVMDialect>();
- typeConverter = std::make_shared<LLVMTypeConverter>(context);
- }
+//===----------------------------------------------------------------------===//
+// DynamicConvertToLLVM
+//===----------------------------------------------------------------------===//
- if (!filterDialects.empty()) {
- // Test mode: Populate only patterns from the specified dialects. Produce
- // an error if the dialect is not loaded or does not implement the
- // interface.
- for (std::string &dialectName : filterDialects) {
- Dialect *dialect = context->getLoadedDialect(dialectName);
- if (!dialect)
- return emitError(UnknownLoc::get(context))
- << "dialect not loaded: " << dialectName << "\n";
- auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
- if (!iface)
- return emitError(UnknownLoc::get(context))
- << "dialect does not implement ConvertToLLVMPatternInterface: "
- << dialectName << "\n";
- if (useConversionAttrs) {
- interfaces->push_back(iface);
- continue;
- }
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
- }
- } else {
- // Normal mode: Populate all patterns from all dialects that implement the
- // interface.
- for (Dialect *dialect : context->getLoadedDialects()) {
- // First time we encounter this dialect: if it implements the interface,
- // let's populate patterns !
- auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
- if (!iface)
- continue;
- if (useConversionAttrs) {
+/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
+/// the IR to configure the conversion to LLVM.
+struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
+ /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
+ /// to partially configure the conversion process.
+ std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+ interfaces;
+ using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
+
+ /// Collect the dialect interfaces used to configure the conversion process.
+ LogicalResult initialize() final {
+ auto interfaces =
+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
+ // Collect the interfaces.
+ if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
interfaces->push_back(iface);
- continue;
- }
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
- }
- }
-
- if (useConversionAttrs) {
- this->interfaces = interfaces;
- } else {
- this->patterns =
- std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
- this->target = target;
- this->typeConverter = typeConverter;
- }
+ })))
+ return failure();
+ this->interfaces = interfaces;
return success();
}
- void runOnOperation() final {
- // Fast path:
- if (!useConversionAttrs) {
- if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
- signalPassFailure();
- return;
- }
- // Slow path with conversion attributes.
- MLIRContext *context = &getContext();
+ /// Configure the conversion process and apply the conversion driver.
+ LogicalResult transform(Operation *op, AnalysisManager manager) const final {
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
- LLVMTypeConverter typeConverter(context);
+ // Get the data layout analysis.
+ const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
+ LLVMTypeConverter typeConverter(context, &dlAnalysis);
// Configure the conversion with dialect level interfaces.
for (ConvertToLLVMPatternInterface *iface : *interfaces)
@@ -163,18 +141,105 @@ class ConvertToLLVMPass
patterns);
// Configure the conversion attribute interfaces.
- populateOpConvertToLLVMConversionPatterns(getOperation(), target,
- typeConverter, patterns);
+ populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
+ patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ return failure();
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPass
+//===----------------------------------------------------------------------===//
+
+/// This is a generic pass to convert to LLVM, it uses the
+/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
+/// the injection of conversion patterns.
+class ConvertToLLVMPass
+ : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
+ std::shared_ptr<const ConvertToLLVMPassInterface> impl;
+
+public:
+ using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
+ void getDependentDialects(DialectRegistry &registry) const final {
+ ConvertToLLVMPassInterface::getDependentDialects(registry);
+ }
+
+ LogicalResult initialize(MLIRContext *context) final {
+ std::shared_ptr<ConvertToLLVMPassInterface> impl;
+ // Choose the pass implementation.
+ if (useDynamic)
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ else
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ if (failed(impl->initialize()))
+ return failure();
+ this->impl = impl;
+ return success();
+ }
+
+ void runOnOperation() final {
+ if (failed(impl->transform(getOperation(), getAnalysisManager())))
+ return signalPassFailure();
}
};
} // namespace
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPassInterface
+//===----------------------------------------------------------------------===//
+
+ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
+ MLIRContext *context, ArrayRef<std::string> filterDialects)
+ : context(context), filterDialects(filterDialects) {}
+
+void ConvertToLLVMPassInterface::getDependentDialects(
+ DialectRegistry &registry) {
+ registry.insert<LLVM::LLVMDialect>();
+ registry.addExtensions<LoadDependentDialectExtension>();
+}
+
+LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
+ llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
+ if (!filterDialects.empty()) {
+ // Test mode: Populate only patterns from the specified dialects. Produce
+ // an error if the dialect is not loaded or does not implement the
+ // interface.
+ for (StringRef dialectName : filterDialects) {
+ Dialect *dialect = context->getLoadedDialect(dialectName);
+ if (!dialect)
+ return emitError(UnknownLoc::get(context))
+ << "dialect not loaded: " << dialectName << "\n";
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ return emitError(UnknownLoc::get(context))
+ << "dialect does not implement ConvertToLLVMPatternInterface: "
+ << dialectName << "\n";
+ visitor(iface);
+ }
+ } else {
+ // Normal mode: Populate all patterns from all dialects that implement the
+ // interface.
+ for (Dialect *dialect : context->getLoadedDialects()) {
+ // First time we encounter this dialect: if it implements the interface,
+ // let's populate patterns !
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+ visitor(iface);
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
void mlir::registerConvertToLLVMDependentDialectLoading(
DialectRegistry &registry) {
registry.addExtensions<LoadDependentDialectExtension>();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index 132eb473cb9189..ed7fa6508d5ade 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s
// CHECK-LABEL: gpu.module @nvvm_module
gpu.module @nvvm_module [#nvvm.target] {
@@ -7,7 +7,7 @@ gpu.module @nvvm_module [#nvvm.target] {
// CHECK: = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%tIdX = gpu.thread_id x
- // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+ // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%laneId = gpu.lane_id
%sum = index.add %tIdX, %laneId
>From 3bf2f34cc2c43361b73389eb426b71d6254f110a Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 23 Nov 2024 23:07:00 +0000
Subject: [PATCH 4/4] fix shared lib build
---
mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
index a9b49391e36a70..c71711ba2ebedb 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass
LINK_LIBS PUBLIC
MLIRIR
+ MLIRConvertToLLVMInterface
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRPass
More information about the Mlir-commits
mailing list