[Mlir-commits] [mlir] [mlir][SPIR-V] Update the `ConvertToSPIRV` pass to use dialect interfaces (PR #102046)

Fabian Mora llvmlistbot at llvm.org
Mon Aug 5 12:47:44 PDT 2024


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/102046

This patch updates the base implementation of `ConvertToSPIRV` to be more like the implementation of `ConvertToLLVM`. `ConvertToLLVM` relies on dialect interfaces for configuring the conversion, allowing out-of-tree dialects to participate in the pass if they implement the interface.

This patch introduces the `ConvertToSPIRVPatternInterface` dialect interface, allowing the configuration of the conversion to SPIR-V on a per-dialect basis.

Finally, this patch adds the dialect interfaces for all previously supported dialects in the previous implementation of the `ConvertToSPIRV` pass.

Note:
The SCF to SPIR-V conversion was left inside the pass because it depends on `ScfToSPIRVContext`; removing this dependency is a TODO for a future patch.

>From 3130abbb69fe91e9f7c053c86a92d1d5fc8c06a2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 5 Aug 2024 19:27:27 +0000
Subject: [PATCH] [mlir][SPIR-V] Update the `ConvertToSPIRV` pass to use
 dialect interfaces.

This patch updates the base implementation of `ConvertToSPIRV` to be more like
the implementation of `ConvertToLLVM`. `ConvertToLLVM` relies on dialect
interfaces for configuring the conversion, allowing out-of-tree dialects to
participate in the pass if they implement the interface.

This patch introduces the `ConvertToSPIRVPatternInterface` dialect interface,
allowing the configuration of the conversion to SPIR-V on a per-dialect
basis.

Finally, this patch adds the dialect interfaces for all previously supported
dialects in the previous implementation of the `ConvertToSPIRV` pass.

Note:
The SCF to SPIR-V conversion was left inside the pass because it depends on
`ScfToSPIRVContext`; removing this dependency is a TODO for a future patch.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.h    |   4 +
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.h   |   5 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.h       |   7 +
 .../ConvertToSPIRV/ToSPIRVInterface.h         |  55 ++++++++
 .../mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h |   3 +
 .../Conversion/IndexToSPIRV/IndexToSPIRV.h    |   4 +
 .../mlir/Conversion/MathToSPIRV/MathToSPIRV.h |   3 +
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.h  |   3 +
 mlir/include/mlir/Conversion/Passes.td        |   2 +
 .../mlir/Conversion/UBToSPIRV/UBToSPIRV.h     |   6 +-
 .../Conversion/VectorToSPIRV/VectorToSPIRV.h  |   3 +
 mlir/include/mlir/InitAllExtensions.h         |  18 +++
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  38 ++++++
 .../Conversion/ArithToSPIRV/CMakeLists.txt    |   1 +
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp |  30 +++++
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |  11 ++
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 126 ++++++++++++++----
 .../ConvertToSPIRV/ToSPIRVInterface.cpp       |  32 +++++
 .../Conversion/FuncToSPIRV/FuncToSPIRV.cpp    |  29 ++++
 .../Conversion/IndexToSPIRV/IndexToSPIRV.cpp  |  35 +++++
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    |  33 +++++
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           |  34 +++++
 mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp   |  28 ++++
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  31 +++++
 mlir/lib/Dialect/Arith/IR/ArithDialect.cpp    |   2 +
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp |   3 +
 mlir/lib/Dialect/Func/IR/FuncOps.cpp          |   2 +
 mlir/lib/Dialect/Index/IR/IndexDialect.cpp    |   2 +
 mlir/lib/Dialect/Math/IR/MathDialect.cpp      |   2 +
 mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp  |   2 +
 mlir/lib/Dialect/UB/IR/UBOps.cpp              |   2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   2 +
 32 files changed, 530 insertions(+), 28 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
 create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp

diff --git a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
index bb30deb9dc10e..cadf0b2872bea 100644
--- a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
@@ -26,6 +26,10 @@ void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns);
 
 std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `arith`
+/// dialect.
+void registerConvertArithToSPIRVInterface(DialectRegistry &registry);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
index 43578ffffae2d..276818973c3f8 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
@@ -14,6 +14,7 @@
 #define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
 
 namespace mlir {
+class DialectRegistry;
 class RewritePatternSet;
 class SPIRVTypeConverter;
 
@@ -22,6 +23,10 @@ namespace cf {
 /// ops to SPIR-V ops.
 void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `cf`
+/// dialect.
+void registerConvertControlFlowToSPIRVInterface(DialectRegistry &registry);
 } // namespace cf
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
index 3852782247527..3062eb5464c53 100644
--- a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
@@ -11,12 +11,19 @@
 
 #include <memory>
 
+#include "mlir/Pass/Pass.h"
+
 namespace mlir {
 class Pass;
+class DialectRegistry;
 
 #define GEN_PASS_DECL_CONVERTTOSPIRVPASS
 #include "mlir/Conversion/Passes.h.inc"
 
+/// Register the extension that will load dependent dialects for SPIR-V
+/// conversion. This is useful to implement a pass similar to
+/// "convert-to-spirv".
+void registerConvertToSPIRVDependentDialectLoading(DialectRegistry &registry);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
new file mode 100644
index 0000000000000..917b81dd237e2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
@@ -0,0 +1,55 @@
+//===- ToSPIRVInterface.h - Conversion to SPIRV iface -*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir {
+class ConversionTarget;
+class SPIRVTypeConverter;
+class MLIRContext;
+class Operation;
+class RewritePatternSet;
+
+/// Base class for dialect interfaces providing translation to SPIR-V.
+/// Dialects that can be translated should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on SPIR-V
+/// IR. The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class ConvertToSPIRVPatternInterface
+    : public DialectInterface::Base<ConvertToSPIRVPatternInterface> {
+public:
+  ConvertToSPIRVPatternInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Hook for derived dialect interfaces to load the dialects they
+  /// target. The SPIRVDialect is implicitly already loaded, but this
+  /// method allows loading other intermediate dialects used in the
+  /// conversion.
+  virtual void loadDependentDialects(MLIRContext *context) const {}
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  virtual void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const = 0;
+};
+
+/// Recursively walk the IR and collect all dialects implementing the interface,
+/// and populate the conversion patterns.
+void populateConversionTargetFromOperation(Operation *op,
+                                           ConversionTarget &target,
+                                           SPIRVTypeConverter &typeConverter,
+                                           RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
index 2fa55f40dd970..42711fb6e4b51 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
@@ -24,6 +24,9 @@ class SPIRVTypeConverter;
 void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns);
 
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `func`
+/// dialect.
+void registerConvertFuncToSPIRVInterface(DialectRegistry &registry);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
index 58a1c5246eef9..fad570591983c 100644
--- a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -24,6 +24,10 @@ namespace index {
 void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
                                   RewritePatternSet &patterns);
 std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `index`
+/// dialect.
+void registerConvertIndexToSPIRVInterface(DialectRegistry &registry);
 } // namespace index
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
index 10090268a4663..9a9edc87f3446 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
@@ -23,6 +23,9 @@ class SPIRVTypeConverter;
 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns);
 
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `math`
+/// dialect.
+void registerConvertMathToSPIRVInterface(DialectRegistry &registry);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 54711c8ad727f..77f6cdd2935df 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -70,6 +70,9 @@ void convertMemRefTypesAndAttrs(
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns);
 
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `memref`
+/// dialect.
+void registerConvertMemRefToSPIRVInterface(DialectRegistry &registry);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..6a7d1434dd66d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -45,6 +45,8 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     "vector::VectorDialect",
   ];
   let options = [
+    ListOption<"filterDialects", "filter-dialects", "std::string",
+               "Test conversion patterns of only the specified dialects">,
     Option<"runSignatureConversion", "run-signature-conversion", "bool",
     /*default=*/"true",
     "Run function signature conversion to convert vector types">,
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
index 3843f2707a520..88cb58df4fc69 100644
--- a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -12,7 +12,7 @@
 #include <memory>
 
 namespace mlir {
-
+class DialectRegistry;
 class SPIRVTypeConverter;
 class RewritePatternSet;
 class Pass;
@@ -23,6 +23,10 @@ class Pass;
 namespace ub {
 void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
                                          RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `ub`
+/// dialect.
+void registerConvertUBToSPIRVInterface(DialectRegistry &registry);
 } // namespace ub
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..5184b82c33faf 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -32,6 +32,9 @@ void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 void populateVectorReductionToSPIRVDotProductPatterns(
     RewritePatternSet &patterns);
 
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `vector`
+/// dialect.
+void registerConvertVectorToSPIRVInterface(DialectRegistry &registry);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..d3aab3a0ff8df 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -15,14 +15,22 @@
 #define MLIR_INITALLEXTENSIONS_H_
 
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -66,6 +74,16 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertNVVMToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
 
+  // Register all conversions to SPIR-V extensions.
+  arith::registerConvertArithToSPIRVInterface(registry);
+  cf::registerConvertControlFlowToSPIRVInterface(registry);
+  registerConvertFuncToSPIRVInterface(registry);
+  index::registerConvertIndexToSPIRVInterface(registry);
+  registerConvertMathToSPIRVInterface(registry);
+  registerConvertMemRefToSPIRVInterface(registry);
+  ub::registerConvertUBToSPIRVInterface(registry);
+  registerConvertVectorToSPIRVInterface(registry);
+
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
   bufferization::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index e6c01f063e8b8..603d96462abb5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,7 +9,9 @@
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 
 #include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -1367,3 +1369,39 @@ struct ConvertArithToSPIRVPass
 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
   return std::make_unique<ConvertArithToSPIRVPass>();
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert arith to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+    arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    // Fail hard when there are any remaining 'arith' ops.
+    target.addIllegalDialect<arith::ArithDialect>();
+  }
+};
+} // namespace
+
+void mlir::arith::registerConvertArithToSPIRVInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
index a5385d9cee6af..0ddb1700e4922 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToSPIRV
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRArithTransforms
   MLIRFuncToSPIRV
   MLIRSPIRVConversion
   MLIRSPIRVDialect
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index f96bfd6f788b9..1e701f729e1ea 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
 #include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -114,3 +115,32 @@ void mlir::cf::populateControlFlowToSPIRVPatterns(
 
   patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert cf to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    // TODO: We should also take care of block argument type conversion.
+    cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::cf::registerConvertControlFlowToSPIRVInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23f..15ec580a90ba0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   ConvertToSPIRVPass.cpp
+  ToSPIRVInterface.cpp
 )
 
 add_mlir_conversion_library(MLIRConvertToSPIRVPass
@@ -31,3 +32,13 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRVectorToSPIRV
   MLIRVectorTransforms
   )
+
+add_mlir_conversion_library(MLIRConvertToSPIRVInterface
+  ToSPIRVInterface.cpp
+
+  DEPENDS
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSupport
+)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94..9b5780fd95dd0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -7,24 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
-#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
-#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
-#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
-#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <memory>
 
 #define DEBUG_TYPE "convert-to-spirv"
@@ -37,15 +28,91 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+/// This DialectExtension can be attached to the context, which will invoke the
+/// `apply()` method for every loaded dialect. If a dialect implements the
+/// `ConvertToSPIRVPatternInterface` interface, we load dependent dialects
+/// through the interface. This extension is loaded in the context before
+/// starting a pass pipeline that involves dialect conversion to SPIR-V.
+class LoadDependentDialectExtension : public DialectExtensionBase {
+public:
+  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
 
-/// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
-    : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
-  using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
+  void apply(MLIRContext *context,
+             MutableArrayRef<Dialect *> dialects) const final {
+    LLVM_DEBUG(llvm::dbgs() << "Convert to SPIR-V extension load\n");
+    for (Dialect *dialect : dialects) {
+      auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+      if (!iface)
+        continue;
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Convert to SPIR-V found dialect interface for "
+                 << dialect->getNamespace() << "\n");
+      iface->loadDependentDialects(context);
+    }
+  }
 
-  void runOnOperation() override {
-    Operation *op = getOperation();
+  /// Return a copy of this extension.
+  std::unique_ptr<DialectExtensionBase> clone() const final {
+    return std::make_unique<LoadDependentDialectExtension>(*this);
+  }
+};
+
+/// This is a generic pass to convert to SPIR-V; it uses the
+/// `ConvertToSPIRVPatternInterface` dialect interface to delegate to dialects
+/// the injection of conversion patterns.
+class ConvertToSPIRVPass
+    : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+  std::shared_ptr<const SmallVector<ConvertToSPIRVPatternInterface *>>
+      interfaces;
+
+public:
+  using impl::ConvertToSPIRVPassBase<
+      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<spirv::SPIRVDialect>();
+    registry.addExtensions<LoadDependentDialectExtension>();
+  }
+
+  LogicalResult initialize(MLIRContext *context) final {
+    auto interfaces =
+        std::make_shared<SmallVector<ConvertToSPIRVPatternInterface *>>();
+    if (!filterDialects.empty()) {
+      // Test mode: Populate only patterns from the specified dialects. Produce
+      // an error if the dialect is not loaded or does not implement the
+      // interface.
+      for (std::string &dialectName : filterDialects) {
+        Dialect *dialect = context->getLoadedDialect(dialectName);
+        if (!dialect)
+          return emitError(UnknownLoc::get(context))
+                 << "dialect not loaded: " << dialectName << "\n";
+        auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+        if (!iface)
+          return emitError(UnknownLoc::get(context))
+                 << "dialect does not implement "
+                    "ConvertToSPIRVPatternInterface: "
+                 << dialectName << "\n";
+        interfaces->push_back(iface);
+      }
+    } else {
+      // Normal mode: Populate all patterns from all dialects that implement the
+      // interface.
+      for (Dialect *dialect : context->getLoadedDialects()) {
+        // Collect the interface if this dialect implements it; the patterns
+        // are populated later, in `runOnOperation`.
+        auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+        if (!iface)
+          continue;
+        interfaces->push_back(iface);
+      }
+    }
+
+    this->interfaces = interfaces;
+    return success();
+  }
+
+  void runOnOperation() final {
     MLIRContext *context = &getContext();
+    Operation *op = getOperation();
 
     // Unroll vectors in function signatures to native size.
     if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
@@ -55,26 +122,31 @@ struct ConvertToSPIRVPass final
     if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
       return signalPassFailure();
 
+    // Lookup the target.
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    // Create and configure the conversion infrastructure.
     std::unique_ptr<ConversionTarget> target =
         SPIRVConversionTarget::get(targetAttr);
     SPIRVTypeConverter typeConverter(targetAttr);
     RewritePatternSet patterns(context);
-    ScfToSPIRVContext scfToSPIRVContext;
 
-    // Populate patterns for each dialect.
-    arith::populateCeilFloorDivExpandOpsPatterns(patterns);
-    arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
-    populateFuncToSPIRVPatterns(typeConverter, patterns);
-    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
-    populateVectorToSPIRVPatterns(typeConverter, patterns);
+    // Configure the conversion with dialect interfaces.
+    for (ConvertToSPIRVPatternInterface *iface : *interfaces)
+      iface->populateConvertToSPIRVConversionPatterns(*target, typeConverter,
+                                                      patterns);
+
+    // TODO: Incorporate SCF to SPIR-V into the interface.
+    ScfToSPIRVContext scfToSPIRVContext;
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
-    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
+    // Apply the conversion.
     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
-      return signalPassFailure();
+      signalPassFailure();
   }
 };
-
 } // namespace
+
+void mlir::registerConvertToSPIRVDependentDialectLoading(
+    DialectRegistry &registry) {
+  registry.addExtensions<LoadDependentDialectExtension>();
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp
new file mode 100644
index 0000000000000..5c631f18cf782
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ToSPIRVInterface.cpp
@@ -0,0 +1,32 @@
+//===- ToSPIRVInterface.cpp - MLIR SPIRV Conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+void mlir::populateConversionTargetFromOperation(
+    Operation *root, ConversionTarget &target,
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
+  DenseSet<Dialect *> dialects;
+  root->walk([&](Operation *op) {
+    Dialect *dialect = op->getDialect();
+    if (!dialects.insert(dialect).second)
+      return;
+    // First time we encounter this dialect: if it implements the interface,
+    // let's populate patterns!
+    auto *iface = dyn_cast<ConvertToSPIRVPatternInterface>(dialect);
+    if (!iface)
+      return;
+    iface->populateConvertToSPIRVConversionPatterns(target, typeConverter,
+                                                    patterns);
+  });
+}
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
index 4740b7cc6c385..a1403a37fce8a 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
 #include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -93,3 +94,31 @@ void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
   patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Func dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Adds the Func and BuiltinFunc to-SPIR-V patterns to `patterns`. `target`
+  /// is accepted to satisfy the interface but is not modified here.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateFuncToSPIRVPatterns(typeConverter, patterns);
+    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertFuncToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index b58efc096e2ea..8b312f1d4c517 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -416,3 +417,37 @@ struct ConvertIndexToSPIRVPass
   }
 };
 } // namespace
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Index dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Configures `target` legality and adds the index to SPIR-V patterns.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    // Fail hard when there are any remaining 'index' ops.
+    target.addIllegalDialect<index::IndexDialect>();
+
+    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+/// Registers an extension adding the interface above to `index::IndexDialect`.
+void mlir::index::registerConvertIndexToSPIRVInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 0b29c93e2d890..03085bbf08188 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -10,7 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
 #include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -449,3 +451,34 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 }
 
 } // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Math dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Marks `UnrealizedConversionCastOp` legal on `target` and adds the math to
+  /// SPIR-V conversion patterns to `patterns`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    populateMathToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+/// Registers an extension adding the interface above to `math::MathDialect`.
+void mlir::registerConvertMathToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 90b0d727ddee7..512139ca7dbc5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -10,9 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -935,3 +938,34 @@ void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                                     patterns.getContext());
 }
 } // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// MemRef dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Marks `UnrealizedConversionCastOp` legal on `target` and adds the memref
+  /// to SPIR-V conversion patterns to `patterns`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertMemRefToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index 001b7fefb175d..d4a0431419aba 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
@@ -82,3 +83,30 @@ void mlir::ub::populateUBToSPIRVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<PoisonOpLowering>(converter, patterns.getContext());
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// UB dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Adds the UB to SPIR-V conversion patterns to `patterns`. `target` is
+  /// accepted to satisfy the interface but is not modified here.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::ub::registerConvertUBToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ub::UBDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..9edc341960e27 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -977,3 +978,33 @@ void mlir::populateVectorReductionToSPIRVDotProductPatterns(
     RewritePatternSet &patterns) {
   patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
 }
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Vector dialect implementation of `ConvertToSPIRVPatternInterface`.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+  using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<spirv::SPIRVDialect>();
+  }
+
+  /// Marks `UnrealizedConversionCastOp` legal on `target` and adds the vector
+  /// to SPIR-V conversion patterns to `patterns`.
+  void populateConvertToSPIRVConversionPatterns(
+      ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    populateVectorToSPIRVPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertVectorToSPIRVInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+    dialect->addInterfaces<ToSPIRVDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 042acf6100900..0fc3b12468c9b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -49,6 +50,7 @@ void arith::ArithDialect::initialize() {
       >();
   addInterfaces<ArithInlinerInterface>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, ArithDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, ArithDialect>();
   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                            SelectOp>();
   declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConstantOp,
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85..64eec7b34d6fe 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -70,6 +71,8 @@ void ControlFlowDialect::initialize() {
       >();
   addInterfaces<ControlFlowInlinerInterface>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface,
+                           ControlFlowDialect>();
   declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
                             CondBranchOp>();
   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index c719981769b9e..634018057c754 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -43,6 +44,7 @@ void FuncDialect::initialize() {
       >();
   declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, FuncDialect>();
   declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
                             FuncOp, ReturnOp>();
 }
diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
index 183d0e33b2523..45684f7fe5097 100644
--- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 
 using namespace mlir;
 using namespace mlir::index;
@@ -20,6 +21,7 @@ void IndexDialect::initialize() {
   registerAttributes();
   registerOperations();
   declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, IndexDialect>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
index 285b5ca594050..ec71bc7bd8fe0 100644
--- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -36,4 +37,5 @@ void mlir::math::MathDialect::initialize() {
       >();
   addInterfaces<MathInlinerInterface>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, MathDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, MathDialect>();
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 3a8bd12ba2586..dbaec4efa31a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -48,6 +49,7 @@ void mlir::memref::MemRefDialect::initialize() {
       >();
   addInterfaces<MemRefInlinerInterface>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, MemRefDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, MemRefDialect>();
   declarePromisedInterfaces<bufferization::AllocationOpInterface, AllocOp,
                             AllocaOp, ReallocOp>();
   declarePromisedInterfaces<RuntimeVerifiableOpInterface, CastOp, ExpandShapeOp,
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index 5b2cfe7bf4264..e346f8830a141 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 #include "mlir/IR/Builders.h"
@@ -47,6 +48,7 @@ void UBDialect::initialize() {
       >();
   addInterfaces<UBInlinerInterface>();
   declarePromisedInterface<ConvertToLLVMPatternInterface, UBDialect>();
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, UBDialect>();
 }
 
 Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab3..ec75d44138c8b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -407,6 +408,7 @@ void VectorDialect::initialize() {
 
   addInterfaces<VectorInlinerInterface>();
 
+  declarePromisedInterface<ConvertToSPIRVPatternInterface, VectorDialect>();
   declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                             TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                             YieldOp>();



More information about the Mlir-commits mailing list