[Mlir-commits] [mlir] [mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded conversions. (PR #124439)
Ivan Butygin
llvmlistbot at llvm.org
Sat Jan 25 17:59:25 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/124439
Using `ConvertToLLVMPatternInterface` removes hardcoded, dialect-specific conversions from these passes and, more importantly, lets downstream projects inject their own op/type translations here by registering the corresponding interface.
>From 35fc50b0bc8363ce5b85a7300c9adb24431a6025 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 1 Jan 2025 17:40:07 +0100
Subject: [PATCH 1/2] [mlir][vector] Create `VectorToLLVMDialectInterface`
Create `VectorToLLVMDialectInterface`, which allows automatic conversion discovery by the generic `--convert-to-llvm` pass.
This only covers the final dialect conversion step, not any earlier preparation steps.
Also, there is currently no way to pass additional parameters through this conversion interface, but most users use the default parameters anyway.
---
.../VectorToLLVM/ConvertVectorToLLVM.h | 3 +++
mlir/include/mlir/InitAllExtensions.h | 2 ++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 25 +++++++++++++++++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 ++
.../vector-to-llvm-interface.mlir | 14 +++++++++++
5 files changed, 46 insertions(+)
create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 5fda62e3584c79..1e29bfeb9c3921 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -24,6 +24,9 @@ void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+namespace vector {
+void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
+}
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 14a6a2787b3a5d..887db344ed88b6 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -26,6 +26,7 @@
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
@@ -76,6 +77,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
+ vector::registerConvertVectorToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a1e21cb524bd9a..7dbe24ced593f8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -1933,3 +1934,27 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
patterns.add<VectorMatmulOpConversion>(converter);
patterns.add<VectorFlatTransposeOpConversion>(converter);
}
+
+namespace {
+struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<LLVM::LLVMDialect>();
+ }
+
+ /// Populate the vector-to-LLVM conversion patterns using the default
+ /// options. The conversion target is left unmodified.
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ populateVectorToLLVMConversionPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::vector::registerConvertVectorToLLVMInterface(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+ dialect->addInterfaces<VectorToLLVMDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b495..d4e83ff60d44d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -429,6 +430,7 @@ void VectorDialect::initialize() {
TransferWriteOp>();
declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
+ declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}
/// Materialize a single constant operation from a given attribute value with
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
new file mode 100644
index 00000000000000..5252bb25ecab54
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -0,0 +1,14 @@
+// Most of the vector lowering is tested in vector-to-llvm.mlir; this file is only an interface smoke test.
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector" --split-input-file %s | FileCheck %s
+
+func.func @bitcast_f32_to_i32_vector_0d(%arg0: vector<f32>) -> vector<i32> {
+ %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
+ return %0 : vector<i32>
+}
+
+// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d
+// CHECK-SAME: %[[ARG_0:.*]]: vector<f32>
+// CHECK: %[[VEC_F32_1D:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VEC_I32_1D:.*]] = llvm.bitcast %[[VEC_F32_1D]] : vector<1xf32> to vector<1xi32>
+// CHECK: %[[VEC_I32_0D:.*]] = builtin.unrealized_conversion_cast %[[VEC_I32_1D]] : vector<1xi32> to vector<i32>
+// CHECK: return %[[VEC_I32_0D]] : vector<i32>
>From 2497c43fba7a246a3efea864b9bdc454ec003b8b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 26 Jan 2025 02:54:12 +0100
Subject: [PATCH 2/2] [mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion
interface instead of hardcoded conversions.
---
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 37 ++++++++++++-------
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 30 ++++++++-------
2 files changed, 40 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 11363a0d60ebfa..669e2651e63fee 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -11,19 +11,14 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -346,6 +341,11 @@ struct LowerGpuOpsToNVVMOpsPass
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
using Base::Base;
+ void getDependentDialects(DialectRegistry &registry) const override final {
+ Base::getDependentDialects(registry);
+ registerConvertToLLVMDependentDialectLoading(registry);
+ }
+
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
@@ -376,17 +376,24 @@ struct LowerGpuOpsToNVVMOpsPass
LLVMTypeConverter converter(m.getContext(), options);
configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
+ LLVMConversionTarget target(getContext());
+
+ for (Dialect *dialect : getContext().getLoadedDialects()) {
+ if (isa<math::MathDialect>(dialect))
+ continue;
+
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
- arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
- populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
- populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
- populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
if (this->hasRedux)
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
- LLVMConversionTarget target(getContext());
configureGpuToNVVMConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
@@ -472,8 +479,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
+
+ // Use a higher benefit so these patterns are tried before the generic LLVM lowering patterns.
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
- converter);
+ converter, /*benefit*/ 10);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index afebded1c3ea40..2c281e580754e6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
@@ -19,8 +18,8 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -28,8 +27,6 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
this->runtime = runtime;
}
+ void getDependentDialects(DialectRegistry &registry) const override final {
+ Base::getDependentDialects(registry);
+ registerConvertToLLVMDependentDialectLoading(registry);
+ }
+
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
MLIRContext *ctx = m.getContext();
@@ -289,18 +291,20 @@ struct LowerGpuOpsToROCDLOpsPass
});
RewritePatternSet llvmPatterns(ctx);
+ LLVMConversionTarget target(getContext());
+
+ for (Dialect *dialect : ctx->getLoadedDialects()) {
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
- mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
*maybeChipset);
- populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
- populateMathToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
- populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
- populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
- LLVMConversionTarget target(getContext());
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
More information about the Mlir-commits
mailing list