[Mlir-commits] [mlir] 7c3ae48 - [mlir][spirv] Replace SPIRVOpLowering with OpConversionPattern

Lei Zhang llvmlistbot at llvm.org
Sat Jan 9 05:08:11 PST 2021


Author: Lei Zhang
Date: 2021-01-09T08:04:53-05:00
New Revision: 7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4

URL: https://github.com/llvm/llvm-project/commit/7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4
DIFF: https://github.com/llvm/llvm-project/commit/7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4.diff

LOG: [mlir][spirv] Replace SPIRVOpLowering with OpConversionPattern

The dialect conversion framework was enhanced to handle type
conversion automatically. OpConversionPattern already contains
a pointer to the TypeConverter. There is no need to duplicate it
in a separate subclass. This removes the only reason for a
SPIRVOpLowering subclass. It adapts to use core infrastructure
and simplifies the code.

Also added a utility function to OpConversionPattern for getting
TypeConverter as a certain subclass.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D94080

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index fddf84859bc2..4143091543d6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -63,19 +63,6 @@ class SPIRVTypeConverter : public TypeConverter {
   spirv::TargetEnv targetEnv;
 };
 
-/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
-template <typename SourceOp>
-class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
-public:
-  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
-                  PatternBenefit benefit = 1)
-      : OpConversionPattern<SourceOp>(context, benefit),
-        typeConverter(typeConverter) {}
-
-protected:
-  SPIRVTypeConverter &typeConverter;
-};
-
 /// Appends to a pattern list additional patterns for translating the builtin
 /// `func` op to the SPIR-V dialect. These patterns do not handle shader
 /// interface/ABI; they convert function parameters to be of SPIR-V allowed

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e02cf8fe4c0a..51c7788ffb14 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -341,6 +341,13 @@ class ConversionPattern : public RewritePattern {
   /// does not require type conversion.
   TypeConverter *getTypeConverter() const { return typeConverter; }
 
+  template <typename ConverterTy>
+  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
+                   ConverterTy *>
+  getTypeConverter() const {
+    return static_cast<ConverterTy *>(typeConverter);
+  }
+
 protected:
   /// See `RewritePattern::RewritePattern` for information on the other
   /// available constructors.

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index e84269e9418d..d66f9c66c1da 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
 
@@ -26,9 +28,9 @@ namespace {
 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
 /// builtin variables.
 template <typename SourceOp, spirv::BuiltIn builtin>
-class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
+class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
 public:
-  using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
@@ -38,9 +40,9 @@ class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
 /// builtin variables.
 template <typename SourceOp, spirv::BuiltIn builtin>
-class SingleDimLaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
+class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
 public:
-  using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
@@ -51,9 +53,9 @@ class SingleDimLaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
 /// a constant with WorkgroupSize decoration. So here we cannot generate a
 /// builtin variable; instead the information in the `spv.entry_point_abi`
 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
-class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
+class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
 public:
-  using SPIRVOpLowering<gpu::BlockDimOp>::SPIRVOpLowering;
+  using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
@@ -61,9 +63,9 @@ class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
 };
 
 /// Pattern to convert a kernel function in GPU dialect within a spv.module.
-class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
+class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
 public:
-  using SPIRVOpLowering<gpu::GPUFuncOp>::SPIRVOpLowering;
+  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
@@ -74,9 +76,9 @@ class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
 };
 
 /// Pattern to convert a gpu.module to a spv.module.
-class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
+class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
 public:
-  using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering;
+  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
@@ -85,9 +87,9 @@ class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
 
 /// Pattern to convert a gpu.return into a SPIR-V return.
 // TODO: This can go to DRR when GPU return has operands.
-class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
+class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
 public:
-  using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
+  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
@@ -102,17 +104,14 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
 
 static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
   auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
-  if (!dimAttr) {
-    return {};
-  }
-  if (dimAttr.getValue() == "x") {
-    return 0;
-  } else if (dimAttr.getValue() == "y") {
-    return 1;
-  } else if (dimAttr.getValue() == "z") {
-    return 2;
-  }
-  return {};
+  if (!dimAttr)
+    return llvm::None;
+
+  return llvm::StringSwitch<Optional<int32_t>>(dimAttr.getValue())
+      .Case("x", 0)
+      .Case("y", 1)
+      .Case("z", 2)
+      .Default(llvm::None);
 }
 
 template <typename SourceOp, spirv::BuiltIn builtin>
@@ -150,7 +149,8 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
 
   auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
   auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
-  auto convertedType = typeConverter.convertType(op.getResult().getType());
+  auto convertedType =
+      getTypeConverter()->convertType(op.getResult().getType());
   if (!convertedType)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
@@ -164,7 +164,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
 
 // Legalizes a GPU function as an entry SPIR-V function.
 static spirv::FuncOp
-lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
+lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
                      ConversionPatternRewriter &rewriter,
                      spirv::EntryPointABIAttr entryPointInfo,
                      ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
@@ -266,7 +266,7 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
     return failure();
   }
   spirv::FuncOp newFuncOp = lowerAsEntryFunction(
-      funcOp, typeConverter, rewriter, entryPointAttr, argABI);
+      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
   if (!newFuncOp)
     return failure();
   newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
@@ -344,5 +344,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
                                       spirv::BuiltIn::NumSubgroups>,
       SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                       spirv::BuiltIn::SubgroupSize>,
-      WorkGroupSizeConversion>(context, typeConverter);
+      WorkGroupSizeConversion>(typeConverter, context);
 }

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 8133a37aa7ad..0db760b17d7c 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
@@ -44,10 +45,9 @@ namespace {
 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
 /// that the linalg.generic op is performing reduction with a workload size that
 /// can fit in one workgroup.
-class SingleWorkgroupReduction final
-    : public SPIRVOpLowering<linalg::GenericOp> {
-public:
-  using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;
+struct SingleWorkgroupReduction final
+    : public OpConversionPattern<linalg::GenericOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   /// Matches the given linalg.generic op as performing reduction and returns
   /// the binary op kind if successful.
@@ -142,9 +142,11 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 
   // TODO: Load to Workgroup storage class first.
 
+  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+
   // Get the input element accessed by this invocation.
   Value inputElementPtr = spirv::getElementPtr(
-      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
+      *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
   Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
 
   // Perform the group reduction operation.
@@ -163,10 +165,10 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 
   // Get the output element accessed by this reduction.
   Value zero = spirv::ConstantOp::getZero(
-      typeConverter.getIndexType(rewriter.getContext()), loc, rewriter);
+      typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
   SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
   Value outputElementPtr =
-      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
+      spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
                            zeroIndices, loc, rewriter);
 
   // Write out the final reduction result. This should be only conducted by one
@@ -204,5 +206,5 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
-  patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
+  patterns.insert<SingleWorkgroupReduction>(typeConverter, context);
 }

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index da2488db1182..93caa3294408 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -16,9 +16,14 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Context
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 struct ScfToSPIRVContextImpl {
   // Map between the spirv region control flow operation (spv.loop or
@@ -37,20 +42,40 @@ struct ScfToSPIRVContextImpl {
 ScfToSPIRVContext::ScfToSPIRVContext() {
   impl = std::make_unique<ScfToSPIRVContextImpl>();
 }
+
 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
 
+//===----------------------------------------------------------------------===//
+// Pattern Declarations
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// Common class for all vector to GPU patterns.
 template <typename OpTy>
-class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
+class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
 public:
   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
                           ScfToSPIRVContextImpl *scfToSPIRVContext)
-      : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
-        scfToSPIRVContext(scfToSPIRVContext) {}
+      : OpConversionPattern<OpTy>::OpConversionPattern(context),
+        scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
 
 protected:
   ScfToSPIRVContextImpl *scfToSPIRVContext;
+  // FIXME: We explicitly keep a reference of the type converter here instead of
+  // passing it to OpConversionPattern during construction. This effectively
+  // bypasses the conversion framework's automation on type conversion. This is
+  // needed right now because the conversion framework will unconditionally
+  // legalize all types used by SCF ops upon discovering them, for example, the
+  // types of loop carried values. We use SPIR-V variables for those loop
+  // carried values. Depending on the available capabilities, the SPIR-V
+  // variable can be different, for example, cooperative matrix or normal
+  // variable. We'd like to detach the conversion of the loop carried values
+  // from the SCF ops (which is mainly a region). So we need to "mark" types
+  // used by SCF ops as legal, if to use the conversion framework for type
+  // conversion. There isn't a straightforward way to do that yet, as when
+  // converting types, ops aren't taken into consideration. Therefore, we just
+  // bypass the framework's type conversion for now.
+  SPIRVTypeConverter &typeConverter;
 };
 
 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
@@ -90,7 +115,6 @@ class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
 /// we load the value from the allocation and use it as the SCF op result.
 template <typename ScfOp, typename OpTy>
 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
-                                  SPIRVTypeConverter &typeConverter,
                                   ConversionPatternRewriter &rewriter,
                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
                                   ArrayRef<Type> returnTypes) {
@@ -117,7 +141,7 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
 }
 
 //===----------------------------------------------------------------------===//
-// scf::ForOp.
+// scf::ForOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult
@@ -196,13 +220,12 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   SmallVector<Type, 8> initTypes;
   for (auto arg : forOperands.initArgs())
     initTypes.push_back(arg.getType());
-  replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
-                        scfToSPIRVContext, initTypes);
+  replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
   return success();
 }
 
 //===----------------------------------------------------------------------===//
-// scf::IfOp.
+// scf::IfOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult
@@ -255,11 +278,15 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
     auto convertedType = typeConverter.convertType(result.getType());
     returnTypes.push_back(convertedType);
   }
-  replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
-                        scfToSPIRVContext, returnTypes);
+  replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
+                        returnTypes);
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// scf::YieldOp
+//===----------------------------------------------------------------------===//
+
 /// Yield is lowered to stores to the VariableOp created during lowering of the
 /// parent region. For loops we also need to update the branch looping back to
 /// the header with the loop carried values.
@@ -290,6 +317,10 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Hooks
+//===----------------------------------------------------------------------===//
+
 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
                                       SPIRVTypeConverter &typeConverter,
                                       ScfToSPIRVContext &scfToSPIRVContext,

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 88d0a818b230..4010484a8e89 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -234,9 +234,9 @@ namespace {
 /// to Workgroup memory when the size is constant.  Note that this pattern needs
 /// to be applied in a pass that runs at least at spv.module scope since it wil
 /// ladd global variables into the spv.module.
-class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
+class AllocOpPattern final : public OpConversionPattern<AllocOp> {
 public:
-  using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;
+  using OpConversionPattern<AllocOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
@@ -246,7 +246,7 @@ class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
       return operation.emitError("unhandled allocation type");
 
     // Get the SPIR-V type for the allocation.
-    Type spirvType = typeConverter.convertType(allocType);
+    Type spirvType = getTypeConverter()->convertType(allocType);
 
     // Insert spv.globalVariable for this allocation.
     Operation *parent =
@@ -276,9 +276,9 @@ class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
 
 /// Removed a deallocation if it is a supported allocation. Currently only
 /// removes deallocation if the memory space is workgroup memory.
-class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
+class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
 public:
-  using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
+  using OpConversionPattern<DeallocOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
@@ -293,15 +293,15 @@ class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
 
 /// Converts unary and binary standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
-class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
+class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
 public:
-  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+  using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() <= 2);
-    auto dstType = this->typeConverter.convertType(operation.getType());
+    auto dstType = this->getTypeConverter()->convertType(operation.getType());
     if (!dstType)
       return failure();
     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
@@ -318,9 +318,9 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
 ///
 /// This cannot be merged into the template unary/binary pattern due to
 /// Vulkan restrictions over spv.SRem and spv.SMod.
-class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
+class SignedRemIOpPattern final : public OpConversionPattern<SignedRemIOp> {
 public:
-  using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;
+  using OpConversionPattern<SignedRemIOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
@@ -332,16 +332,16 @@ class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
 template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
+class BitwiseOpPattern final : public OpConversionPattern<StdOp> {
 public:
-  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+  using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 2);
     auto dstType =
-        this->typeConverter.convertType(operation.getResult().getType());
+        this->getTypeConverter()->convertType(operation.getResult().getType());
     if (!dstType)
       return failure();
     if (isBoolScalarOrVector(operands.front().getType())) {
@@ -356,9 +356,10 @@ class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
 };
 
 /// Converts composite std.constant operation to spv.constant.
-class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
+class ConstantCompositeOpPattern final
+    : public OpConversionPattern<ConstantOp> {
 public:
-  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+  using OpConversionPattern<ConstantOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
@@ -366,9 +367,9 @@ class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
 };
 
 /// Converts scalar std.constant operation to spv.constant.
-class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
+class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
 public:
-  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+  using OpConversionPattern<ConstantOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
@@ -376,9 +377,9 @@ class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
 };
 
 /// Converts floating-point comparison operations to SPIR-V ops.
-class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
+class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
 public:
-  using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
+  using OpConversionPattern<CmpFOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
@@ -386,9 +387,9 @@ class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
 };
 
 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
-class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
+class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
 public:
-  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+  using OpConversionPattern<CmpIOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
@@ -396,9 +397,9 @@ class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
 };
 
 /// Converts integer compare operation to SPIR-V ops.
-class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
+class CmpIOpPattern final : public OpConversionPattern<CmpIOp> {
 public:
-  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+  using OpConversionPattern<CmpIOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
@@ -406,9 +407,9 @@ class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
 };
 
 /// Converts std.load to spv.Load.
-class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
 public:
-  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+  using OpConversionPattern<LoadOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
@@ -416,9 +417,9 @@ class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
 };
 
 /// Converts std.load to spv.Load.
-class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+class LoadOpPattern final : public OpConversionPattern<LoadOp> {
 public:
-  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+  using OpConversionPattern<LoadOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
@@ -426,9 +427,9 @@ class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
 };
 
 /// Converts std.return to spv.Return.
-class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
+class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
 public:
-  using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+  using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
@@ -436,18 +437,18 @@ class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
 };
 
 /// Converts std.select to spv.Select.
-class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
+class SelectOpPattern final : public OpConversionPattern<SelectOp> {
 public:
-  using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
+  using OpConversionPattern<SelectOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Converts std.store to spv.Store on integers.
-class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
 public:
-  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+  using OpConversionPattern<StoreOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
@@ -455,9 +456,9 @@ class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
 };
 
 /// Converts std.store to spv.Store.
-class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+class StoreOpPattern final : public OpConversionPattern<StoreOp> {
 public:
-  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+  using OpConversionPattern<StoreOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
@@ -466,9 +467,9 @@ class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
 
 /// Converts std.zexti to spv.Select if the type of source is i1 or vector of
 /// i1.
-class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
+class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
 public:
-  using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
+  using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
@@ -477,7 +478,8 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
     if (!isBoolScalarOrVector(srcType))
       return failure();
 
-    auto dstType = this->typeConverter.convertType(op.getResult().getType());
+    auto dstType =
+        this->getTypeConverter()->convertType(op.getResult().getType());
     Location loc = op.getLoc();
     Attribute zeroAttr, oneAttr;
     if (auto vectorType = dstType.dyn_cast<VectorType>()) {
@@ -497,9 +499,9 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
 
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
-class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
+class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
 public:
-  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+  using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
@@ -509,7 +511,7 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
     if (isBoolScalarOrVector(srcType))
       return failure();
     auto dstType =
-        this->typeConverter.convertType(operation.getResult().getType());
+        this->getTypeConverter()->convertType(operation.getResult().getType());
     if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.
@@ -523,9 +525,9 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
 };
 
 /// Converts std.xor to SPIR-V operations.
-class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
+class XOrOpPattern final : public OpConversionPattern<XOrOp> {
 public:
-  using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;
+  using OpConversionPattern<XOrOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
@@ -562,7 +564,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
   // std.constant should only have vector or tenor types.
   assert((srcType.isa<VectorType, RankedTensorType>()));
 
-  auto dstType = typeConverter.convertType(srcType);
+  auto dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
     return failure();
 
@@ -645,7 +647,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
   if (!srcType.isIntOrIndexOrFloat())
     return failure();
 
-  Type dstType = typeConverter.convertType(srcType);
+  Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
     return failure();
 
@@ -771,7 +773,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
-        operandType != this->typeConverter.convertType(operandType)) {         \
+        operandType != this->getTypeConverter()->convertType(operandType)) {   \
       return cmpIOp.emitError(                                                 \
           "bitwidth emulation is not implemented yet on unsigned op");         \
     }                                                                          \
@@ -808,6 +810,8 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
+
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   spirv::AccessChainOp accessChainOp =
       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                            loadOperands.indices(), loc, rewriter);
@@ -881,9 +885,9 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
-  auto loadPtr =
-      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
-                           loadOperands.indices(), loadOp.getLoc(), rewriter);
+  auto loadPtr = spirv::getElementPtr(
+      *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return success();
 }
@@ -933,6 +937,7 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
     return failure();
 
   auto loc = storeOp.getLoc();
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   spirv::AccessChainOp accessChainOp =
       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                            storeOperands.indices(), loc, rewriter);
@@ -1010,8 +1015,9 @@ StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
   auto storePtr =
-      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
-                           storeOperands.indices(), storeOp.getLoc(), rewriter);
+      spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
+                           storeOperands.memref(), storeOperands.indices(),
+                           storeOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return success();
@@ -1029,7 +1035,7 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
   if (isBoolScalarOrVector(operands.front().getType()))
     return failure();
 
-  auto dstType = typeConverter.convertType(xorOp.getType());
+  auto dstType = getTypeConverter()->convertType(xorOp.getType());
   if (!dstType)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
@@ -1096,7 +1102,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
       TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
-      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
-                                                          typeConverter);
+      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
+                                                          context);
 }
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a2735e646bec..1509836ef2e2 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -24,8 +24,9 @@ using namespace mlir;
 
 namespace {
 struct VectorBroadcastConvert final
-    : public SPIRVOpLowering<vector::BroadcastOp> {
-  using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -43,8 +44,9 @@ struct VectorBroadcastConvert final
 };
 
 struct VectorExtractOpConvert final
-    : public SPIRVOpLowering<vector::ExtractOp> {
-  using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -60,8 +62,10 @@ struct VectorExtractOpConvert final
   }
 };
 
-struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
-  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
+struct VectorInsertOpConvert final
+    : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -78,8 +82,9 @@ struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
 };
 
 struct VectorExtractElementOpConvert final
-    : public SPIRVOpLowering<vector::ExtractElementOp> {
-  using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
+    : public OpConversionPattern<vector::ExtractElementOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(vector::ExtractElementOp extractElementOp,
                   ArrayRef<Value> operands,
@@ -96,8 +101,9 @@ struct VectorExtractElementOpConvert final
 };
 
 struct VectorInsertElementOpConvert final
-    : public SPIRVOpLowering<vector::InsertElementOp> {
-  using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
+    : public OpConversionPattern<vector::InsertElementOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(vector::InsertElementOp insertElementOp,
                   ArrayRef<Value> operands,
@@ -120,5 +126,5 @@ void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
                   VectorInsertOpConvert, VectorExtractElementOpConvert,
-                  VectorInsertElementOpConvert>(context, typeConverter);
+                  VectorInsertElementOpConvert>(typeConverter, context);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index aeff47a831ef..9b62b4289c77 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -151,9 +151,10 @@ namespace {
 /// variable ABI attributes attached to function arguments and converts all
 /// function argument uses to those global variables. This is necessary because
 /// Vulkan requires all shader entry points to be of void(void) type.
-class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
+class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
 public:
-  using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
+  using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
@@ -214,7 +215,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     }
     signatureConverter.remapInput(argType.index(), replacement);
   }
-  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
                                          &signatureConverter)))
     return failure();
 
@@ -246,7 +247,7 @@ void LowerABIAttributesPass::runOnOperation() {
   });
 
   OwningRewritePatternList patterns;
-  patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
+  patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 9393f3df6425..1c0445290402 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -459,9 +460,9 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
 namespace {
 /// A pattern for rewriting function signature to convert arguments of functions
 /// to be of valid SPIR-V types.
-class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
+class FuncOpConversion final : public OpConversionPattern<FuncOp> {
 public:
-  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+  using OpConversionPattern<FuncOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
@@ -478,7 +479,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
   for (auto argType : enumerate(fnType.getInputs())) {
-    auto convertedType = typeConverter.convertType(argType.value());
+    auto convertedType = getTypeConverter()->convertType(argType.value());
     if (!convertedType)
       return failure();
     signatureConverter.addInputs(argType.index(), convertedType);
@@ -486,7 +487,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   Type resultType;
   if (fnType.getNumResults() == 1)
-    resultType = typeConverter.convertType(fnType.getResult(0));
+    resultType = getTypeConverter()->convertType(fnType.getResult(0));
 
   // Create the converted spv.func op.
   auto newFuncOp = rewriter.create<spirv::FuncOp>(
@@ -504,8 +505,8 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
-                                         &signatureConverter)))
+  if (failed(rewriter.convertRegionTypes(
+          &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
     return failure();
   rewriter.eraseOp(funcOp);
   return success();
@@ -514,7 +515,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 void mlir::populateBuiltinFuncToSPIRVPatterns(
     MLIRContext *context, SPIRVTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<FuncOpConversion>(context, typeConverter);
+  patterns.insert<FuncOpConversion>(typeConverter, context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 850e22465d44..9e972c3a6c57 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // std arithmetic ops
@@ -628,49 +628,59 @@ func @fptosi2(%arg0 : f16) -> i16 {
 
 // -----
 
-// Checks that cast types will be adjusted when no special capabilities for
-// non-32-bit scalar types.
+// Checks that cast types will be adjusted when missing special capabilities for
+// certain non-32-bit scalar types.
 module attributes {
-  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float64], []>, {}>
 } {
 
 // CHECK-LABEL: @fpext1
 // CHECK-SAME: %[[ARG:.*]]: f32
-func @fpext1(%arg0: f16) {
-  // CHECK-NEXT: "use"(%[[ARG]])
+func @fpext1(%arg0: f16) -> f64 {
+  // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = std.fpext %arg0 : f16 to f64
-  "use"(%0) : (f64) -> ()
+  return %0: f64
 }
 
 // CHECK-LABEL: @fpext2
 // CHECK-SAME: %[[ARG:.*]]: f32
-func @fpext2(%arg0 : f32) {
-  // CHECK-NEXT: "use"(%[[ARG]])
+func @fpext2(%arg0 : f32) -> f64 {
+  // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = std.fpext %arg0 : f32 to f64
-  "use"(%0) : (f64) -> ()
+  return %0: f64
 }
 
+} // end module
+
+// -----
+
+// Checks that cast types will be adjusted when missing special capabilities for
+// certain non-32-bit scalar types.
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}>
+} {
+
 // CHECK-LABEL: @fptrunc1
 // CHECK-SAME: %[[ARG:.*]]: f32
-func @fptrunc1(%arg0 : f64) {
-  // CHECK-NEXT: "use"(%[[ARG]])
+func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = std.fptrunc %arg0 : f64 to f16
-  "use"(%0) : (f16) -> ()
+  return %0: f16
 }
 
 // CHECK-LABEL: @fptrunc2
 // CHECK-SAME: %[[ARG:.*]]: f32
-func @fptrunc2(%arg0: f32) {
-  // CHECK-NEXT: "use"(%[[ARG]])
+func @fptrunc2(%arg0: f32) -> f16 {
+  // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = std.fptrunc %arg0 : f32 to f16
-  "use"(%0) : (f16) -> ()
+  return %0: f16
 }
 
 // CHECK-LABEL: @sitofp
-func @sitofp(%arg0 : i64) {
+func @sitofp(%arg0 : i64) -> f64 {
   // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32
   %0 = std.sitofp %arg0 : i64 to f64
-  "use"(%0) : (f64) -> ()
+  return %0: f64
 }
 
 } // end module


        


More information about the Mlir-commits mailing list