[Mlir-commits] [mlir] [mlir][SPIR-V] Update the `ConvertToSPIRV` pass to use dialect interfaces (PR #102046)
Fabian Mora
llvmlistbot at llvm.org
Mon Aug 5 12:47:44 PDT 2024
https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/102046
This patch updates the base implementation of `ConvertToSPIRV` to be more like the implementation of `ConvertToLLVM`. `ConvertToLLVM` relies on dialect interfaces for configuring the conversion, allowing out-of-tree dialects to participate in the pass if they implement the interface.
This patch introduces the `ConvertToSPIRVPatternInterface` dialect interface, which allows the conversion to SPIR-V to be configured on a per-dialect basis.
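Finally, this patch implements the interface for the dialects handled by the previous implementation of the `ConvertToSPIRV` pass (`arith`, `func`, `index`, `ub`, and `vector`; `scf` is discussed below), and additionally for `cf`, `math`, and `memref`.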
Note:
The SCF-to-SPIR-V conversion was left inside the pass because it depends on `ScfToSPIRVContext`; removing this limitation is left as a TODO for a future patch.
>From 3130abbb69fe91e9f7c053c86a92d1d5fc8c06a2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 5 Aug 2024 19:27:27 +0000
Subject: [PATCH] [mlir][SPIR-V] Update the `ConvertToSPIRV` pass to use
dialect interfaces.
This patch updates the base implementation of `ConvertToSPIRV` to be more like
the implementation of `ConvertToLLVM`. `ConvertToLLVM` relies on dialect
interfaces for configuring the conversion, allowing out-of-tree dialects to
participate in the pass if they implement the interface.
This patch introduces the `ConvertToSPIRVPatternInterface` dialect interface,
which allows the conversion to SPIR-V to be configured on a per-dialect
basis.
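Finally, this patch implements the interface for the dialects handled by the
previous implementation of the `ConvertToSPIRV` pass (`arith`, `func`, `index`,
`ub`, and `vector`; `scf` is discussed below), and additionally for `cf`,
`math`, and `memref`.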
Note:
The SCF-to-SPIR-V conversion was left inside the pass because it depends on
`ScfToSPIRVContext`; removing this limitation is left as a TODO for a future
patch.
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.h | 4 +
.../ControlFlowToSPIRV/ControlFlowToSPIRV.h | 5 +
.../ConvertToSPIRV/ConvertToSPIRVPass.h | 7 +
.../ConvertToSPIRV/ToSPIRVInterface.h | 55 ++++++++
.../mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h | 3 +
.../Conversion/IndexToSPIRV/IndexToSPIRV.h | 4 +
.../mlir/Conversion/MathToSPIRV/MathToSPIRV.h | 3 +
.../Conversion/MemRefToSPIRV/MemRefToSPIRV.h | 3 +
mlir/include/mlir/Conversion/Passes.td | 2 +
.../mlir/Conversion/UBToSPIRV/UBToSPIRV.h | 6 +-
.../Conversion/VectorToSPIRV/VectorToSPIRV.h | 3 +
mlir/include/mlir/InitAllExtensions.h | 18 +++
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 38 ++++++
.../Conversion/ArithToSPIRV/CMakeLists.txt | 1 +
.../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 30 +++++
.../Conversion/ConvertToSPIRV/CMakeLists.txt | 11 ++
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 126 ++++++++++++++----
.../ConvertToSPIRV/ToSPIRVInterface.cpp | 32 +++++
.../Conversion/FuncToSPIRV/FuncToSPIRV.cpp | 29 ++++
.../Conversion/IndexToSPIRV/IndexToSPIRV.cpp | 35 +++++
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 33 +++++
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 34 +++++
mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp | 28 ++++
.../VectorToSPIRV/VectorToSPIRV.cpp | 31 +++++
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 2 +
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 3 +
mlir/lib/Dialect/Func/IR/FuncOps.cpp | 2 +
mlir/lib/Dialect/Index/IR/IndexDialect.cpp | 2 +
mlir/lib/Dialect/Math/IR/MathDialect.cpp | 2 +
mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp | 2 +
mlir/lib/Dialect/UB/IR/UBOps.cpp | 2 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +
32 files changed, 530 insertions(+), 28 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp
diff --git a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
index bb30deb9dc10e..cadf0b2872bea 100644
--- a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
@@ -26,6 +26,10 @@ void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `arith`
+/// dialect.
+void registerConvertArithToSPIRVInterface(DialectRegistry &registry);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
index 43578ffffae2d..276818973c3f8 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
@@ -14,6 +14,7 @@
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
namespace mlir {
+class DialectRegistry;
class RewritePatternSet;
class SPIRVTypeConverter;
@@ -22,6 +23,10 @@ namespace cf {
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `cf`
+/// dialect.
+void registerConvertControlFlowToSPIRVInterface(DialectRegistry &registry);
} // namespace cf
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
index 3852782247527..3062eb5464c53 100644
--- a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
@@ -11,12 +11,19 @@
#include <memory>
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
class Pass;
+class DialectRegistry;
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
+/// Register the extension that will load dependent dialects for SPIR-V
+/// conversion. This is useful to implement a pass similar to
+/// "convert-to-spirv".
+void registerConvertToSPIRVDependentDialectLoading(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
new file mode 100644
index 0000000000000..917b81dd237e2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
@@ -0,0 +1,55 @@
+//===- ToSPIRVInterface.h - Conversion to SPIRV iface -*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir {
+class ConversionTarget;
+class SPIRVTypeConverter;
+class MLIRContext;
+class Operation;
+class RewritePatternSet;
+
+/// Base class for dialect interfaces providing conversion to SPIR-V.
+/// Dialects that can be converted should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on SPIR-V
+/// IR. The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class ConvertToSPIRVPatternInterface
+ : public DialectInterface::Base<ConvertToSPIRVPatternInterface> {
+public:
+ ConvertToSPIRVPatternInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Hook for derived dialect interfaces to load the dialects they
+ /// target. The SPIRVDialect is implicitly already loaded, but this
+ /// method allows loading other intermediate dialects used in the
+ /// conversion.
+ virtual void loadDependentDialects(MLIRContext *context) const {}
+
+ /// Hook for derived dialect interfaces to provide conversion patterns
+ /// and to configure operation legality on the conversion target.
+ virtual void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const = 0;
+};
+
+/// Recursively walk the IR nested under `op` and, for each dialect that
+/// implements the interface, populate its conversion patterns and target.
+void populateConversionTargetFromOperation(Operation *op,
+ ConversionTarget &target,
+ SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
index 2fa55f40dd970..42711fb6e4b51 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
@@ -24,6 +24,9 @@ class SPIRVTypeConverter;
void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `func`
+/// dialect.
+void registerConvertFuncToSPIRVInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
index 58a1c5246eef9..fad570591983c 100644
--- a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -24,6 +24,10 @@ namespace index {
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `index`
+/// dialect.
+void registerConvertIndexToSPIRVInterface(DialectRegistry &registry);
} // namespace index
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
index 10090268a4663..9a9edc87f3446 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
@@ -23,6 +23,9 @@ class SPIRVTypeConverter;
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `math`
+/// dialect.
+void registerConvertMathToSPIRVInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 54711c8ad727f..77f6cdd2935df 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -70,6 +70,9 @@ void convertMemRefTypesAndAttrs(
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `memref`
+/// dialect.
+void registerConvertMemRefToSPIRVInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..6a7d1434dd66d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -45,6 +45,8 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
"vector::VectorDialect",
];
let options = [
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "Test conversion patterns of only the specified dialects">,
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">,
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
index 3843f2707a520..88cb58df4fc69 100644
--- a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -12,7 +12,7 @@
#include <memory>
namespace mlir {
-
+class DialectRegistry;
class SPIRVTypeConverter;
class RewritePatternSet;
class Pass;
@@ -23,6 +23,10 @@ class Pass;
namespace ub {
void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `ub`
+/// dialect.
+void registerConvertUBToSPIRVInterface(DialectRegistry &registry);
} // namespace ub
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..5184b82c33faf 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -32,6 +32,9 @@ void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `vector`
+/// dialect.
+void registerConvertVectorToSPIRVInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..d3aab3a0ff8df 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -15,14 +15,22 @@
#define MLIR_INITALLEXTENSIONS_H_
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -66,6 +74,16 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ // Register all conversions to SPIR-V extensions.
+ arith::registerConvertArithToSPIRVInterface(registry);
+ cf::registerConvertControlFlowToSPIRVInterface(registry);
+ registerConvertFuncToSPIRVInterface(registry);
+ index::registerConvertIndexToSPIRVInterface(registry);
+ registerConvertMathToSPIRVInterface(registry);
+ registerConvertMemRefToSPIRVInterface(registry);
+ ub::registerConvertUBToSPIRVInterface(registry);
+ registerConvertVectorToSPIRVInterface(registry);
+
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index e6c01f063e8b8..603d96462abb5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,7 +9,9 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -1367,3 +1369,39 @@ struct ConvertArithToSPIRVPass
std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
return std::make_unique<ConvertArithToSPIRVPass>();
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert arith to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Populate the arith (and ceil/floor-div expansion) patterns, and mark the
+ /// `arith` dialect illegal on the conversion target.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+ arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+
+ // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+ // in patterns for other dialects.
+ target.addLegalOp<UnrealizedConversionCastOp>();
+
+ // Fail hard when there are any remaining 'arith' ops.
+ target.addIllegalDialect<arith::ArithDialect>();
+ }
+};
+} // namespace
+
+void mlir::arith::registerConvertArithToSPIRVInterface(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
index a5385d9cee6af..0ddb1700e4922 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToSPIRV
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRArithTransforms
MLIRFuncToSPIRV
MLIRSPIRVConversion
MLIRSPIRVDialect
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index f96bfd6f788b9..1e701f729e1ea 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -114,3 +115,32 @@ void mlir::cf::populateControlFlowToSPIRVPatterns(
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert cf to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Populate the cf to SPIR-V patterns; the conversion target is left
+ /// unchanged.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ // TODO: We should also take care of block argument type conversion.
+ cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::cf::registerConvertControlFlowToSPIRVInterface(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23f..15ec580a90ba0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
+ ToSPIRVInterface.cpp
)
add_mlir_conversion_library(MLIRConvertToSPIRVPass
@@ -31,3 +32,13 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRVectorToSPIRV
MLIRVectorTransforms
)
+
+add_mlir_conversion_library(MLIRConvertToSPIRVInterface
+ ToSPIRVInterface.cpp
+
+ DEPENDS
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94..9b5780fd95dd0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -7,24 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
-#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
-#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
-#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
-#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
@@ -37,15 +28,94 @@ namespace mlir {
using namespace mlir;
namespace {
+/// This DialectExtension can be attached to the context, which will invoke the
+/// `apply()` method for every loaded dialect. If a dialect implements the
+/// `ConvertToSPIRVPatternInterface` interface, we load dependent dialects
+/// through the interface. This extension is loaded in the context before
+/// starting a pass pipeline that involves dialect conversion to SPIR-V.
+class LoadDependentDialectExtension : public DialectExtensionBase {
+public:
+ LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
-/// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
- : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
- using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
+ void apply(MLIRContext *context,
+ MutableArrayRef<Dialect *> dialects) const final {
+ LLVM_DEBUG(llvm::dbgs() << "Convert to SPIR-V extension load\n");
+ for (Dialect *dialect : dialects) {
+ auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+ if (!iface)
+ continue;
+ LLVM_DEBUG(llvm::dbgs()
+ << "Convert to SPIR-V found dialect interface for "
+ << dialect->getNamespace() << "\n");
+ iface->loadDependentDialects(context);
+ }
+ }
- void runOnOperation() override {
- Operation *op = getOperation();
+ /// Return a copy of this extension.
+ std::unique_ptr<DialectExtensionBase> clone() const final {
+ return std::make_unique<LoadDependentDialectExtension>(*this);
+ }
+};
+
+/// This is a generic pass to convert to SPIR-V, it uses the
+/// `ConvertToSPIRVPatternInterface` dialect interface to delegate to dialects
+/// the injection of conversion patterns.
+class ConvertToSPIRVPass
+ : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+ std::shared_ptr<const SmallVector<ConvertToSPIRVPatternInterface *>>
+ interfaces;
+
+public:
+ using impl::ConvertToSPIRVPassBase<
+ ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+ void getDependentDialects(DialectRegistry &registry) const final {
+ // Keep the dependent dialects declared in Passes.td (e.g. `vector`); this
+ // override would otherwise replace the TableGen-generated registration.
+ ConvertToSPIRVPassBase::getDependentDialects(registry);
+ registry.insert<spirv::SPIRVDialect>();
+ registry.addExtensions<LoadDependentDialectExtension>();
+ }
+
+ LogicalResult initialize(MLIRContext *context) final {
+ auto interfaces =
+ std::make_shared<SmallVector<ConvertToSPIRVPatternInterface *>>();
+ if (!filterDialects.empty()) {
+ // Test mode: Populate only patterns from the specified dialects. Produce
+ // an error if the dialect is not loaded or does not implement the
+ // interface.
+ for (std::string &dialectName : filterDialects) {
+ Dialect *dialect = context->getLoadedDialect(dialectName);
+ if (!dialect)
+ return emitError(UnknownLoc::get(context))
+ << "dialect not loaded: " << dialectName << "\n";
+ auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+ if (!iface)
+ return emitError(UnknownLoc::get(context))
+ << "dialect does not implement "
+ "ConvertToSPIRVPatternInterface: "
+ << dialectName << "\n";
+ interfaces->push_back(iface);
+ }
+ } else {
+ // Normal mode: Populate all patterns from all dialects that implement the
+ // interface.
+ for (Dialect *dialect : context->getLoadedDialects()) {
+ // Collect every loaded dialect implementing the interface; its patterns
+ // are populated later, in `runOnOperation`.
+ auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+ if (!iface)
+ continue;
+ interfaces->push_back(iface);
+ }
+ }
+
+ this->interfaces = interfaces;
+ return success();
+ }
+
+ void runOnOperation() final {
MLIRContext *context = &getContext();
+ Operation *op = getOperation();
// Unroll vectors in function signatures to native size.
if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
@@ -55,26 +125,31 @@ struct ConvertToSPIRVPass final
if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
return signalPassFailure();
+ // Lookup the target.
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ // Create and configure the conversion infrastructure.
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
- ScfToSPIRVContext scfToSPIRVContext;
- // Populate patterns for each dialect.
- arith::populateCeilFloorDivExpandOpsPatterns(patterns);
- arith::populateArithToSPIRVPatterns(typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
- populateFuncToSPIRVPatterns(typeConverter, patterns);
- index::populateIndexToSPIRVPatterns(typeConverter, patterns);
- populateVectorToSPIRVPatterns(typeConverter, patterns);
+ // Configure the conversion with dialect interfaces.
+ for (ConvertToSPIRVPatternInterface *iface : *interfaces)
+ iface->populateConvertToSPIRVConversionPatterns(*target, typeConverter,
+ patterns);
+
+ // TODO: Incorporate SCF to SPIR-V into the interface.
+ ScfToSPIRVContext scfToSPIRVContext;
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
- ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+ // Apply the conversion.
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
- return signalPassFailure();
+ signalPassFailure();
}
};
-
} // namespace
+
+void mlir::registerConvertToSPIRVDependentDialectLoading(
+ DialectRegistry &registry) {
+ registry.addExtensions<LoadDependentDialectExtension>();
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp
new file mode 100644
index 0000000000000..5c631f18cf782
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp
@@ -0,0 +1,32 @@
+//===- ToSPIRVInterface.cpp - MLIR SPIRV Conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+void mlir::populateConversionTargetFromOperation(
+ Operation *root, ConversionTarget &target,
+ SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ DenseSet<Dialect *> dialects;
+ root->walk([&](Operation *op) {
+ // `getDialect()` is null for unregistered ops (and dyn_cast asserts on
+ // null); otherwise populate each interface-implementing dialect only once.
+ Dialect *dialect = op->getDialect();
+ if (!dialect || !dialects.insert(dialect).second)
+ return;
+ auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+ if (!iface)
+ return;
+ iface->populateConvertToSPIRVConversionPatterns(target, typeConverter,
+ patterns);
+ });
+}
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
index 4740b7cc6c385..a1403a37fce8a 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -93,3 +94,31 @@ void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert func to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Populate the func and builtin-func to SPIR-V patterns; the conversion
+ /// target is left unchanged.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ populateFuncToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::registerConvertFuncToSPIRVInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index b58efc096e2ea..8b312f1d4c517 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -8,6 +8,7 @@
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -416,3 +417,37 @@ struct ConvertIndexToSPIRVPass
}
};
} // namespace
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Dialect interface that lets the generic convert-to-spirv pass lower the
+/// `index` dialect to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+
+  /// Ensures the SPIR-V dialect, which these patterns produce, is loaded in
+  /// `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Adds the index-to-SPIR-V patterns. Also configures `target` so that
+  /// leftover `index` ops are illegal while unrealized conversion casts remain
+  /// legal.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+
+    // Fail hard if any 'index' op survives the conversion.
+    target.addIllegalDialect<index::IndexDialect>();
+    // UnrealizedConversionCast bridges to values of other dialects, so their
+    // patterns need not be pulled in here.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+  }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the
+/// ToSPIRVDialectInterface defined above to the index dialect.
+void mlir::index::registerConvertIndexToSPIRVInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 0b29c93e2d890..03085bbf08188 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -10,7 +10,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -449,3 +451,34 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
}
} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Dialect interface that lets the generic convert-to-spirv pass lower the
+/// `math` dialect to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+
+  /// Ensures the SPIR-V dialect, which these patterns produce, is loaded in
+  /// `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Adds the math-to-SPIR-V patterns and keeps unrealized conversion casts
+  /// legal on `target`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateMathToSPIRVPatterns(typeConverter, patterns);
+    // UnrealizedConversionCast bridges to values of other dialects, so their
+    // patterns need not be pulled in here.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+  }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the
+/// ToSPIRVDialectInterface defined above to the math dialect.
+void mlir::registerConvertMathToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 90b0d727ddee7..512139ca7dbc5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -10,9 +10,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -935,3 +938,34 @@ void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.getContext());
}
} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Dialect interface that lets the generic convert-to-spirv pass lower the
+/// `memref` dialect to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+
+  /// Adds the memref-to-SPIR-V patterns and keeps unrealized conversion casts
+  /// legal on `target`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
+    // UnrealizedConversionCast bridges to values of other dialects, so their
+    // patterns need not be pulled in here.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+  }
+
+  /// Ensures the SPIR-V dialect, which the patterns above produce, is loaded
+  /// in `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the
+/// ToSPIRVDialectInterface defined above to the memref dialect.
+void mlir::registerConvertMemRefToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index 001b7fefb175d..d4a0431419aba 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -8,6 +8,7 @@
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
@@ -82,3 +83,30 @@ void mlir::ub::populateUBToSPIRVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<PoisonOpLowering>(converter, patterns.getContext());
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Dialect interface that lets the generic convert-to-spirv pass lower the
+/// `ub` dialect to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+
+  /// Adds the ub-to-SPIR-V patterns to `patterns`. This override leaves
+  /// `target` untouched.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+  }
+
+  /// Ensures the SPIR-V dialect, which the patterns above produce, is loaded
+  /// in `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the
+/// ToSPIRVDialectInterface defined above to the ub dialect.
+void mlir::ub::registerConvertUBToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ub::UBDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..9edc341960e27 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -977,3 +978,33 @@ void mlir::populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns) {
patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Dialect interface that lets the generic convert-to-spirv pass lower the
+/// `vector` dialect to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+
+  /// Adds the vector-to-SPIR-V patterns and keeps unrealized conversion casts
+  /// legal on `target`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateVectorToSPIRVPatterns(typeConverter, patterns);
+    // UnrealizedConversionCast bridges to values of other dialects, so their
+    // patterns need not be pulled in here.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+  }
+
+  /// Ensures the SPIR-V dialect, which the patterns above produce, is loaded
+  /// in `context`.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the
+/// ToSPIRVDialectInterface defined above to the vector dialect.
+void mlir::registerConvertVectorToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 042acf6100900..0fc3b12468c9b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -49,6 +50,7 @@ void arith::ArithDialect::initialize() {
>();
addInterfaces<ArithInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, ArithDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, ArithDialect>();
declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
SelectOp>();
declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConstantOp,
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85..64eec7b34d6fe 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -70,6 +71,8 @@ void ControlFlowDialect::initialize() {
>();
addInterfaces<ControlFlowInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface,
+ ControlFlowDialect>();
declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
CondBranchOp>();
declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index c719981769b9e..634018057c754 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -43,6 +44,7 @@ void FuncDialect::initialize() {
>();
declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, FuncDialect>();
declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
FuncOp, ReturnOp>();
}
diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
index 183d0e33b2523..45684f7fe5097 100644
--- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
using namespace mlir;
using namespace mlir::index;
@@ -20,6 +21,7 @@ void IndexDialect::initialize() {
registerAttributes();
registerOperations();
declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, IndexDialect>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
index 285b5ca594050..ec71bc7bd8fe0 100644
--- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -36,4 +37,5 @@ void mlir::math::MathDialect::initialize() {
>();
addInterfaces<MathInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, MathDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, MathDialect>();
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 3a8bd12ba2586..dbaec4efa31a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -48,6 +49,7 @@ void mlir::memref::MemRefDialect::initialize() {
>();
addInterfaces<MemRefInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, MemRefDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, MemRefDialect>();
declarePromisedInterfaces<bufferization::AllocationOpInterface, AllocOp,
AllocaOp, ReallocOp>();
declarePromisedInterfaces<RuntimeVerifiableOpInterface, CastOp, ExpandShapeOp,
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index 5b2cfe7bf4264..e346f8830a141 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
@@ -47,6 +48,7 @@ void UBDialect::initialize() {
>();
addInterfaces<UBInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, UBDialect>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, UBDialect>();
}
Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab3..ec75d44138c8b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -407,6 +408,7 @@ void VectorDialect::initialize() {
addInterfaces<VectorInlinerInterface>();
+ declarePromisedInterface<ConvertToSPIRVPatternInterface, VectorDialect>();
declarePromisedInterfaces<bufferization::BufferizableOpInterface,
TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
YieldOp>();
More information about the Mlir-commits
mailing list