[Mlir-commits] [mlir] [mlir] Init the `TransformsInterfaces` for configuring transformations (PR #99566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 18 14:14:43 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-gpu
Author: Fabian Mora (fabianmcg)
<details>
<summary>Changes</summary>
This patch adds the `ConversionPatternsAttrInterface` and `OpWithTransformAttrsOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available.
The `ConversionPatternsAttrInterface` allows attributes to configure the dialect conversion infrastructure, including the conversion target, type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure dialect conversion.
The `OpWithTransformAttrsOpInterface` allows interacting with transforms attributes. These attributes allow configuring transformations like dialect conversion with information present in the IR.
Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
---
Patch is 31.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99566.diff
18 Files Affected:
- (added) mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h (+26)
- (modified) mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h (+4)
- (added) mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h (+39)
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+1)
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+3-1)
- (modified) mlir/include/mlir/InitAllExtensions.h (+2)
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+8)
- (added) mlir/include/mlir/Interfaces/TransformsInterfaces.h (+71)
- (added) mlir/include/mlir/Interfaces/TransformsInterfaces.td (+77)
- (modified) mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp (+27-17)
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+64-23)
- (modified) mlir/lib/Conversion/LLVMCommon/CMakeLists.txt (+2)
- (added) mlir/lib/Conversion/LLVMCommon/ConversionAttrOptions.cpp (+27)
- (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+11)
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
- (added) mlir/lib/Interfaces/TransformsInterfaces.cpp (+53)
- (added) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir (+42)
``````````diff
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
new file mode 100644
index 0000000000000..e076132db2a02
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h
@@ -0,0 +1,26 @@
+//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares registration functions for converting GPU to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
+
+namespace mlir {
+class DialectRegistry;
+namespace NVVM {
+/// Registers the `ConversionPatternsAttrInterface` interface on the
+/// `NVVM::NVVMTargetAttr`. This interface populates the conversion target,
+/// LLVM type converter, and pattern set for converting GPU operations to NVVM.
+void registerConvertGpuToNVVMAttrInterface(DialectRegistry &registry);
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index e0f4c71051e50..61741a2678c7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
/// Configure target to convert from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
+/// Configure the LLVM type converter to convert types and address spaces
+/// from the GPU dialect to NVVM.
+void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
+
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h
new file mode 100644
index 0000000000000..8941580efaf3b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/ConversionAttrOptions.h
@@ -0,0 +1,39 @@
+//===- ConversionAttrOptions.h - LLVM conversion options --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares LLVM options for `ConversionPatternsAttrInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
+#define MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
+
+#include "mlir/Interfaces/TransformsInterfaces.h"
+
+namespace mlir {
+class LLVMTypeConverter;
+
+/// Options passed to `ConversionPatternsAttrInterface` attributes when
+/// converting to LLVM; the type converter is an `LLVMTypeConverter`.
+class LLVMConversionPatternAttrOptions : public ConversionPatternAttrOptions {
+public:
+ LLVMConversionPatternAttrOptions(ConversionTarget &target,
+ LLVMTypeConverter &converter);
+ /// LLVM-style RTTI hook (`isa`/`dyn_cast`): matches on the stored TypeID.
+ static bool classof(ConversionPatternAttrOptions const *opts) {
+ return opts->getTypeID() == TypeID::get<LLVMConversionPatternAttrOptions>();
+ }
+
+ /// Returns the type converter as an `LLVMTypeConverter`.
+ LLVMTypeConverter &getLLVMTypeConverter();
+};
+} // namespace mlir
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::LLVMConversionPatternAttrOptions)
+
+#endif // MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a84..4865f485c20d7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -28,6 +28,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/TransformsInterfaces.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c57d291552e60..9d4a4e8ba8553 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -29,6 +29,7 @@ include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TransformsInterfaces.td"
//===----------------------------------------------------------------------===//
// GPU Dialect operations.
@@ -1347,7 +1348,8 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
def GPU_GPUModuleOp : GPU_Op<"module", [
DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
- SymbolTable, Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
+ DeclareOpInterfaceMethods<OpWithTransformAttrsOpInterface>, SymbolTable,
+ Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
]>, Arguments<(ins SymbolNameAttr:$sym_name,
OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler)> {
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..3657e6c47c896 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -18,6 +18,7 @@
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -65,6 +66,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMAttrInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..e941ee862106f 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -36,6 +36,14 @@ mlir_tablegen(DataLayoutTypeInterface.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRDataLayoutInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIRDataLayoutInterfacesIncGen)
+set(LLVM_TARGET_DEFINITIONS TransformsInterfaces.td)
+mlir_tablegen(TransformsAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TransformsAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(TransformsOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TransformsOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTransformsInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRTransformsInterfacesIncGen)
+
add_mlir_doc(DataLayoutInterfaces
DataLayoutAttrInterface
Interfaces/
diff --git a/mlir/include/mlir/Interfaces/TransformsInterfaces.h b/mlir/include/mlir/Interfaces/TransformsInterfaces.h
new file mode 100644
index 0000000000000..d4880f4e6fd68
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TransformsInterfaces.h
@@ -0,0 +1,71 @@
+//===- TransformsInterfaces.h - Transforms interfaces -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares interfaces for managing transformations, including
+// populating pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_H
+#define MLIR_INTERFACES_TRANSFORMSINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class ConversionTarget;
+class RewritePatternSet;
+class TypeConverter;
+
+/// This class serves as an opaque interface for passing options to the
+/// `ConversionPatternsAttrInterface` methods. Derived classes must implement
+/// the `classof` method and use the `MLIR_*_EXPLICIT_TYPE_ID` macros to
+/// ensure type safety.
+class ConversionPatternAttrOptions {
+public:
+ ConversionPatternAttrOptions(ConversionTarget &target,
+ TypeConverter &converter);
+
+ /// Returns the typeID.
+ TypeID getTypeID() const { return typeID; }
+
+ /// Returns a reference to the conversion target to configure.
+ ConversionTarget &getConversionTarget() { return target; }
+
+ /// Returns a reference to the type converter to configure.
+ TypeConverter &getTypeConverter() { return converter; }
+
+protected:
+ /// Derived classes must use this constructor to initialize `typeID` to the
+ /// appropriate value.
+ ConversionPatternAttrOptions(TypeID typeID, ConversionTarget &target,
+ TypeConverter &converter);
+ // Conversion target.
+ ConversionTarget &target;
+ // Type converter.
+ TypeConverter &converter;
+
+private:
+ TypeID typeID;
+};
+
+/// Helper function for populating dialect conversion patterns. If `op`
+/// implements the `OpWithTransformAttrsOpInterface` interface, then the
+/// conversion pattern attributes provided by the interface will be used to
+/// configure the conversion target, type converter, and the pattern set.
+void populateOpConversionPatterns(Operation *op,
+ ConversionPatternAttrOptions &options,
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#include "mlir/Interfaces/TransformsAttrInterfaces.h.inc"
+
+#include "mlir/Interfaces/TransformsOpInterfaces.h.inc"
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::ConversionPatternAttrOptions)
+
+#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/TransformsInterfaces.td b/mlir/include/mlir/Interfaces/TransformsInterfaces.td
new file mode 100644
index 0000000000000..4fb9cea95e5e8
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TransformsInterfaces.td
@@ -0,0 +1,77 @@
+//===- TransformsInterfaces.td - Transforms interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines interfaces for managing transformations, including populating
+// pattern rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
+#define MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns attribute interface
+//===----------------------------------------------------------------------===//
+
+def ConversionPatternsAttrInterface :
+ AttrInterface<"ConversionPatternsAttrInterface"> {
+ let description = [{
+ This interface allows using attributes to configure the dialect conversion
+ infrastructure, including:
+ - The conversion target.
+ - The type converter.
+ - The pattern set.
+
+ The conversion target and type converter are passed through the
+ `ConversionPatternAttrOptions` class. Passing them through this class
+ and by reference allows sub-classing the base options class, enabling
+ specializations like `LLVMConversionPatternAttrOptions` for converting to
+ LLVM.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the dialect conversion target, type converter, and pattern set.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"populateConversionPatterns",
+ /*args=*/(ins "::mlir::ConversionPatternAttrOptions&":$options,
+ "::mlir::RewritePatternSet&":$patternSet)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operation with transform attributes interface
+//===----------------------------------------------------------------------===//
+
+def OpWithTransformAttrsOpInterface :
+ OpInterface<"OpWithTransformAttrsOpInterface"> {
+ let description = [{
+ Interface for interacting with transform attributes. These attributes
+ allow configuring transformations, such as dialect conversion, with
+ information present in the IR.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate `attrs` with the conversion pattern attributes used to
+ configure the dialect conversion of this operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getConversionPatternAttrs",
+ /*args=*/(ins
+ "::llvm::SmallVectorImpl<::mlir::ConversionPatternsAttrInterface>&":$attrs)
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 6135117348a5b..66715912d9647 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -8,10 +8,12 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TransformsInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -61,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
/// the injection of conversion patterns.
class ConvertToLLVMPass
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
- std::shared_ptr<const FrozenRewritePatternSet> patterns;
- std::shared_ptr<const ConversionTarget> target;
- std::shared_ptr<const LLVMTypeConverter> typeConverter;
+ std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
+ interfaces;
public:
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -73,11 +74,8 @@ class ConvertToLLVMPass
}
LogicalResult initialize(MLIRContext *context) final {
- RewritePatternSet tempPatterns(context);
- auto target = std::make_shared<ConversionTarget>(*context);
- target->addLegalDialect<LLVM::LLVMDialect>();
- auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
-
+ auto interfaces =
+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
if (!filterDialects.empty()) {
// Test mode: Populate only patterns from the specified dialects. Produce
// an error if the dialect is not loaded or does not implement the
@@ -92,8 +90,7 @@ class ConvertToLLVMPass
return emitError(UnknownLoc::get(context))
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialectName << "\n";
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
+ interfaces->push_back(iface);
}
} else {
// Normal mode: Populate all patterns from all dialects that implement the
@@ -104,20 +101,33 @@ class ConvertToLLVMPass
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
- iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
- tempPatterns);
+ interfaces->push_back(iface);
}
}
- this->patterns =
- std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
- this->target = target;
- this->typeConverter = typeConverter;
+ this->interfaces = interfaces;
return success();
}
void runOnOperation() final {
- if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ LLVMTypeConverter typeConverter(context);
+
+ // Configure the conversion with dialect level interfaces.
+ for (ConvertToLLVMPatternInterface *iface : *interfaces)
+ iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+ patterns);
+
+ // Configure the conversion with attribute interfaces.
+ LLVMConversionPatternAttrOptions opts(target, typeConverter);
+ populateOpConversionPatterns(getOperation(), opts, patterns);
+
+ // Apply the conversion.
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..2a47555d67f28 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -17,6 +17,8 @@
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -261,29 +263,7 @@ struct LowerGpuOpsToNVVMOpsPass
}
LLVMTypeConverter converter(m.getContext(), options);
- // NVVM uses alloca in the default address space to represent private
- // memory allocations, so drop private annotations. NVVM uses address
- // space 3 for shared memory. NVVM uses the default address space to
- // represent global memory.
- populateGpuMemorySpaceAttributeConversions(
- converter, [](gpu::AddressSpace space) -> unsigned {
- switch (space) {
- case gpu::AddressSpace::Global:
- return static_cast<unsigned>(
- NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- case gpu::AddressSpace::Workgroup:
- return static_cast<unsigned>(
- NVVM::NVVMMemorySpace::kSharedMemorySpace);
- case gpu::AddressSpace::Private:
- return 0;
- }
- llvm_unreachable("unknown address space enum value");
- return 0;
- });
- // Lowering for MMAMatrixType.
- converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
- return convertMMAToLLVMType(type);
- });
+ configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -318,6 +298,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
+void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConvert...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/99566
More information about the Mlir-commits
mailing list