[Mlir-commits] [mlir] 0d7ff22 - [mlir] Refactor TypeConverter to add conversions without inheritance

River Riddle llvmlistbot at llvm.org
Tue Feb 18 16:18:48 PST 2020


Author: River Riddle
Date: 2020-02-18T16:17:48-08:00
New Revision: 0d7ff220ed0eeffda3839546ac1f30e98f60128e

URL: https://github.com/llvm/llvm-project/commit/0d7ff220ed0eeffda3839546ac1f30e98f60128e
DIFF: https://github.com/llvm/llvm-project/commit/0d7ff220ed0eeffda3839546ac1f30e98f60128e.diff

LOG: [mlir] Refactor TypeConverter to add conversions without inheritance

Summary:
This revision refactors the TypeConverter class so that type conversions are no longer added through inheritance. Instead, it moves to a registration-based system, where conversion callbacks are added to the converter with `addConversion`. This method takes a conversion callback, which must be convertible to one of the following forms (where `T` is a class derived from `Type`):
* Optional<Type>(T type)
   - This form represents a 1-1 type conversion. It should return a null type
     (e.g. `nullptr`) or `llvm::None` to signify failure. A null type ends
     the conversion with a failure, while `llvm::None` allows the converter
     to try another conversion function to perform the conversion.
* Optional<LogicalResult>(T type, SmallVectorImpl<Type> &results)
   - This form represents a 1-N type conversion. It should return
     `failure` or `llvm::None` to signify a failed conversion. As above,
     `failure` is final, while `llvm::None` allows the converter to try
     another conversion function to perform the conversion. If the new set
     of types is empty, the type is removed and any usages of the existing
     value are expected to be removed during conversion.

When attempting to convert a type, the TypeConverter tries each registered conversion callback in turn, starting with the most recently registered one, and stops at the first callback that does not return `llvm::None`.
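To make these dispatch rules concrete, here is a small standalone model of them. It is a sketch, not MLIR code: types are plain strings (an empty string stands in for a null `Type`), `ToyTypeConverter`, `ConversionFn`, and `makeExampleConverter` are hypothetical names, and every callback is written directly in the normalized 1-N form, with `std::optional<bool>` in place of `Optional<LogicalResult>`. The 1-1 `convertType` overload is layered on the 1-N one in the same way as in the patch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy model (hypothetical, not MLIR): a type is a string, and the empty
// string plays the role of a null Type.
using Type = std::string;

// Normalized callback form, standing in for
// Optional<LogicalResult>(Type, SmallVectorImpl<Type> &):
//   std::nullopt -> not handled here, try the next callback
//   true / false -> conversion succeeded / failed, stop searching
using ConversionFn =
    std::function<std::optional<bool>(const Type &, std::vector<Type> &)>;

class ToyTypeConverter {
public:
  void addConversion(ConversionFn fn) { conversions.push_back(std::move(fn)); }

  // 1-N conversion: try callbacks from the most recently registered to the
  // oldest, stopping at the first one that does not return std::nullopt.
  bool convertType(const Type &t, std::vector<Type> &results) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it)
      if (std::optional<bool> result = (*it)(t, results))
        return *result;
    return false; // no callback handled the type
  }

  // 1-1 conversion layered on the 1-N one: returns a null ("") type unless
  // exactly one type was produced.
  Type convertType(const Type &t) const {
    std::vector<Type> results;
    if (!convertType(t, results))
      return Type();
    return results.size() == 1 ? results.front() : Type();
  }

private:
  std::vector<ConversionFn> conversions;
};

ToyTypeConverter makeExampleConverter() {
  ToyTypeConverter converter;
  // Registered first, so tried last: pass every type through unchanged.
  converter.addConversion(
      [](const Type &t, std::vector<Type> &results) -> std::optional<bool> {
        results.push_back(t);
        return true;
      });
  // index -> i64 (a 1-1 conversion).
  converter.addConversion(
      [](const Type &t, std::vector<Type> &results) -> std::optional<bool> {
        if (t != "index")
          return std::nullopt;
        results.push_back("i64");
        return true;
      });
  // i16 -> nothing (1-N with zero results): the type is removed.
  converter.addConversion(
      [](const Type &t, std::vector<Type> &) -> std::optional<bool> {
        if (t != "i16")
          return std::nullopt;
        return true;
      });
  // f80 -> hard failure: older callbacks are never consulted.
  converter.addConversion(
      [](const Type &t, std::vector<Type> &) -> std::optional<bool> {
        if (t != "f80")
          return std::nullopt;
        return false;
      });
  return converter;
}
```

With this setup, `index` converts to `i64`, `f32` falls through to the pass-through callback, `i16` converts successfully to zero types (so the 1-1 overload reports a null type), and `f80` fails without the pass-through callback ever being consulted.
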

Differential Revision: https://reviews.llvm.org/D74584

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/include/mlir/Support/STLExtras.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/TestDialect/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index d0643947ed20..8af0e4fb0b25 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -196,15 +196,26 @@ convert types. Several of these hooks are detailed below:
 ```c++
 class TypeConverter {
  public:
-  /// This hook allows for converting a type. This function should return
-  /// failure if no valid conversion exists, success otherwise. If the new set
-  /// of types is empty, the type is removed and any usages of the existing
-  /// value are expected to be removed during conversion.
-  virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
-
-  /// This hook simplifies defining 1-1 type conversions. This function returns
-  /// the type to convert to on success, and a null type on failure.
-  virtual Type convertType(Type t);
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms(where `T` is a class derived from `Type`:
+  ///   * Optional<Type>(T)
+  ///     - This form represents a 1-1 type conversion. It should return nullptr
+  ///       or `llvm::None` to signify failure. If `llvm::None` is returned, the
+  ///       converter is allowed to try another conversion function to perform
+  ///       the conversion.
+  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///     - This form represents a 1-N type conversion. It should return
+  ///       `failure` or `llvm::None` to signify a failed conversion. If the new
+  ///       set of types is empty, the type is removed and any usages of the
+  ///       existing value are expected to be removed during conversion. If
+  ///       `llvm::None` is returned, the converter is allowed to try another
+  ///       conversion function to perform the conversion.
+  ///
+  /// When attempting to convert a type, e.g. via `convertType`, the
+  /// `TypeConverter` will invoke each of the converters starting with the one
+  /// most recently registered.
+  template <typename ConversionFnT>
+  void addConversion(ConversionFnT &&callback);
 
   /// This hook allows for materializing a conversion from a set of types into
   /// one result type by generating a cast operation of some kind. The generated

diff  --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
index e2e5bd2880a6..4124f3f0e3b2 100644
--- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
+++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
@@ -16,14 +16,8 @@ class MLIRContext;
 class ModuleOp;
 template <typename T> class OpPassBase;
 
-class LinalgTypeConverter : public LLVMTypeConverter {
-public:
-  using LLVMTypeConverter::LLVMTypeConverter;
-  Type convertType(Type t) override;
-};
-
 /// Populate the given list with patterns that convert from Linalg to LLVM.
-void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
+void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns,
                                             MLIRContext *ctx);
 

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 81e3ea8f1735..db3a948f000f 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -78,10 +78,6 @@ class LLVMTypeConverter : public TypeConverter {
   LLVMTypeConverter(MLIRContext *ctx,
                     const LLVMTypeConverterCustomization &custom);
 
-  /// Convert types to LLVM IR.  This calls `convertAdditionalType` to convert
-  /// non-standard or non-builtin types.
-  Type convertType(Type t) override;
-
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
@@ -129,8 +125,6 @@ class LLVMTypeConverter : public TypeConverter {
   LLVM::LLVMDialect *llvmDialect;
 
 private:
-  Type convertStandardType(Type type);
-
   // Convert a function type.  The arguments and results are converted one by
   // one.  Additionally, if the function returns more than one value, pack the
   // results into an LLVM IR structure type so that the converted function type

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 2a70173ec37c..85b42eeea291 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -27,10 +27,7 @@ namespace mlir {
 /// pointers to structs.
 class SPIRVTypeConverter : public TypeConverter {
 public:
-  using TypeConverter::TypeConverter;
-
-  /// Converts the given standard `type` to SPIR-V correspondence.
-  Type convertType(Type type) override;
+  SPIRVTypeConverter();
 
   /// Gets the SPIR-V correspondence for the standard index type.
   static Type getIndexType(MLIRContext *context);

diff  --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 0e6b95ac5bdc..14336aad6a25 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -376,6 +376,10 @@ struct FunctionTraits<ReturnType (*)(Args...), false> {
   template <size_t i>
   using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
 };
+/// Overload for non-class function type references.
+template <typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType (&)(Args...), false>
+    : public FunctionTraits<ReturnType (*)(Args...)> {};
 } // end namespace mlir
 
 // Allow tuples to be usable as DenseMap keys.

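The new `FunctionTraits` overload appears to be needed because `addConversion` deduces `FnT` through a forwarding reference. When a function is passed by name, as `TestTypeConverter` does with `addConversion(convertType)`, `FnT` is deduced as a function reference type such as `LogicalResult (&)(Type, SmallVectorImpl<Type> &)`, which the existing pointer specialization does not match. The sketch below is a simplified, hypothetical re-creation of the trait (not MLIR's exact definition), plus a stand-in function `firstArgIsInt` that computes its default `T` argument the same way `addConversion` does.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>

// Simplified, hypothetical FunctionTraits modeled on the helper in
// mlir/Support/STLExtras.h. Class types such as lambdas defer to operator().
template <typename T, bool isClass = std::is_class<T>::value>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};

// A lambda's (const) call operator.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
  template <std::size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};

// Pointer to a free (or static member) function.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (*)(Args...), false> {
  template <std::size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};

// Reference to a free function, the case added by this patch. Without it,
// FunctionTraits<R (&)(Args...)> falls back to the primary template and
// fails to compile.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (&)(Args...), false>
    : public FunctionTraits<ReturnType (*)(Args...)> {};

// Mimics the shape of addConversion: FnT is deduced through a forwarding
// reference and the first parameter type is recovered from it.
template <typename FnT,
          typename T = typename FunctionTraits<FnT>::template arg_t<0>>
bool firstArgIsInt(FnT &&) {
  return std::is_same<T, int>::value;
}

static long widen(int x) { return x; }
```

Passing `widen` by name deduces `FnT` as `long (&)(int)` and exercises the new reference overload, while `&widen` and lambdas go through the pointer and call-operator cases.
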
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index cd148f2fd2ea..664005402ead 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -91,15 +91,37 @@ class TypeConverter {
     SmallVector<Type, 4> argTypes;
   };
 
-  /// This hook allows for converting a type. This function should return
-  /// failure if no valid conversion exists, success otherwise. If the new set
-  /// of types is empty, the type is removed and any usages of the existing
-  /// value are expected to be removed during conversion.
-  virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms(where `T` is a class derived from `Type`:
+  ///   * Optional<Type>(T)
+  ///     - This form represents a 1-1 type conversion. It should return nullptr
+  ///       or `llvm::None` to signify failure. If `llvm::None` is returned, the
+  ///       converter is allowed to try another conversion function to perform
+  ///       the conversion.
+  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///     - This form represents a 1-N type conversion. It should return
+  ///       `failure` or `llvm::None` to signify a failed conversion. If the new
+  ///       set of types is empty, the type is removed and any usages of the
+  ///       existing value are expected to be removed during conversion. If
+  ///       `llvm::None` is returned, the converter is allowed to try another
+  ///       conversion function to perform the conversion.
+  /// Note: When attempting to convert a type, e.g. via 'convertType', the
+  ///       mostly recently added conversions will be invoked first.
+  template <typename FnT,
+            typename T = typename FunctionTraits<FnT>::template arg_t<0>>
+  void addConversion(FnT &&callback) {
+    registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
+  }
+
+  /// Convert the given type. This function should return failure if no valid
+  /// conversion exists, success otherwise. If the new set of types is empty,
+  /// the type is removed and any usages of the existing value are expected to
+  /// be removed during conversion.
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
 
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
-  virtual Type convertType(Type t) { return t; }
+  Type convertType(Type t);
 
   /// Convert the given set of types, filling 'results' as necessary. This
   /// returns failure if the conversion of any of the types fails, success
@@ -138,6 +160,50 @@ class TypeConverter {
                                            Location loc) {
     llvm_unreachable("expected 'materializeConversion' to be overridden");
   }
+
+private:
+  /// The signature of the callback used to convert a type. If the new set of
+  /// types is empty, the type is removed and any usages of the existing value
+  /// are expected to be removed during conversion.
+  using ConversionCallbackFn =
+      std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
+
+  /// Generate a wrapper for the given callback. This allows for accepting
+  /// different callback forms, that all compose into a single version.
+  /// With callback of form: `Optional<Type>(T)`
+  template <typename T, typename FnT>
+  std::enable_if_t<is_invocable<FnT, T>::value, ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    return wrapCallback<T>([=](T type, SmallVectorImpl<Type> &results) {
+      if (Optional<Type> resultOpt = callback(type)) {
+        bool wasSuccess = static_cast<bool>(resultOpt.getValue());
+        if (wasSuccess)
+          results.push_back(resultOpt.getValue());
+        return Optional<LogicalResult>(success(wasSuccess));
+      }
+      return Optional<LogicalResult>();
+    });
+  }
+  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
+  template <typename T, typename FnT>
+  std::enable_if_t<!is_invocable<FnT, T>::value, ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    return [=](Type type,
+               SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
+      T derivedType = type.dyn_cast<T>();
+      if (!derivedType)
+        return llvm::None;
+      return callback(derivedType, results);
+    };
+  }
+
+  /// Register a type conversion.
+  void registerConversion(ConversionCallbackFn callback) {
+    conversions.emplace_back(std::move(callback));
+  }
+
+  /// The set of registered conversion functions.
+  SmallVector<ConversionCallbackFn, 4> conversions;
 };
 
 //===----------------------------------------------------------------------===//

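The two `wrapCallback` overloads above pick a callback form at compile time: a callback invocable with a single `T` is treated as the 1-1 form and adapted into the 1-N form, and the 1-N wrapper filters on `dyn_cast<T>`, returning `llvm::None` for any other type. The standalone sketch below reproduces that shape with the C++17 standard library. It is a toy, not MLIR code: `Type`, `Kind`, `IntegerType::dynCast`, and the use of `std::optional<bool>` in place of `Optional<LogicalResult>` are all hypothetical, and `std::is_invocable_v` stands in for the `is_invocable` trait used in the patch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <type_traits>
#include <vector>

// Hypothetical toy type model (not MLIR): a kind plus a bit width, where
// Kind::Null stands in for a null Type.
enum class Kind { Null, Integer, Index };
struct Type {
  Kind kind = Kind::Null;
  unsigned width = 0;
  explicit operator bool() const { return kind != Kind::Null; }
};

// Toy "derived" type with a dyn_cast-like helper.
struct IntegerType {
  unsigned width;
  static std::optional<IntegerType> dynCast(Type t) {
    if (t.kind != Kind::Integer)
      return std::nullopt;
    return IntegerType{t.width};
  }
};

// The single normalized callback form (cf. ConversionCallbackFn).
using ConversionCallbackFn =
    std::function<std::optional<bool>(Type, std::vector<Type> &)>;

// 1-N form: only run the callback if the type is a T; otherwise defer.
template <typename T, typename FnT>
std::enable_if_t<!std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT callback) {
  return [callback](Type type,
                    std::vector<Type> &results) -> std::optional<bool> {
    std::optional<T> derived = T::dynCast(type);
    if (!derived)
      return std::nullopt; // not a T: let another callback try
    return callback(*derived, results);
  };
}

// 1-1 form: adapt to the 1-N form, then reuse the overload above.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT callback) {
  return wrapCallback<T>(
      [callback](T type, std::vector<Type> &results) -> std::optional<bool> {
        std::optional<Type> resultOpt = callback(type);
        if (!resultOpt)
          return std::nullopt; // "None": let another callback try
        if (!*resultOpt)
          return false; // null type: the conversion failed
        results.push_back(*resultOpt);
        return true;
      });
}
```

Because the adapted lambda takes two arguments, the recursive `wrapCallback<T>` call can only resolve to the 1-N overload, so both forms end up as the same `ConversionCallbackFn`.
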
diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2164c038c124..df0a0535e7ff 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -28,25 +28,6 @@ using namespace mlir;
 
 namespace {
 
-/// Derived type converter for GPU to NVVM lowering. The GPU dialect uses memory
-/// space 5 for private memory attributions, but NVVM represents private
-/// memory allocations as local `alloca`s in the default address space. This
-/// converter drops the private memory space to support the use case above.
-class NVVMTypeConverter : public LLVMTypeConverter {
-public:
-  using LLVMTypeConverter::LLVMTypeConverter;
-
-  Type convertType(Type type) override {
-    auto memref = type.dyn_cast<MemRefType>();
-    if (memref &&
-        memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
-      type = MemRefType::Builder(memref).setMemorySpace(0);
-    }
-
-    return LLVMTypeConverter::convertType(type);
-  }
-};
-
 /// Converts all_reduce op to LLVM/NVVM ops.
 struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
   using AccumulatorFactory =
@@ -683,8 +664,19 @@ class LowerGpuOpsToNVVMOpsPass
 public:
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
+
+    /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
+    /// space 5 for private memory attributions, but NVVM represents private
+    /// memory allocations as local `alloca`s in the default address space. This
+    /// converter drops the private memory space to support the use case above.
+    LLVMTypeConverter converter(m.getContext());
+    converter.addConversion([&](MemRefType type) -> Optional<Type> {
+      if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
+        return llvm::None;
+      return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
+    });
+
     OwningRewritePatternList patterns;
-    NVVMTypeConverter converter(m.getContext());
     populateStdToLLVMConversionPatterns(converter, patterns);
     populateGpuToNVVMConversionPatterns(converter, patterns);
     ConversionTarget target(getContext());

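The NVVM callback relies on registration order. It is added after the `LLVMTypeConverter` constructor has registered the default `MemRefType` conversion, so it is tried first: it returns `llvm::None` for non-private memrefs, and for private ones it re-enters `converter.convertType` with memory space 0. The rewritten type no longer matches the override, so the recursive call falls through to the default conversion. The toy sketch below models only this control flow; `MemRef`, `ToyConverter`, `populateNVVMLikeConversions`, and the `ptr<N>` strings are hypothetical, and memory space 5 is taken from the comment in the patch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model (not MLIR): a memref is reduced to its memory space,
// and the "converted" type is a string; an empty string means failure.
struct MemRef {
  unsigned memorySpace;
};
using Converted = std::string;

class ToyConverter {
public:
  using Fn = std::function<std::optional<Converted>(MemRef)>;
  void addConversion(Fn fn) { conversions.push_back(std::move(fn)); }

  // Try callbacks from most recently registered to oldest; std::nullopt
  // defers to the next callback.
  Converted convertType(MemRef type) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it)
      if (std::optional<Converted> result = (*it)(type))
        return *result;
    return Converted();
  }

private:
  std::vector<Fn> conversions;
};

// Memory space used for private memory attributions (per the patch comment).
constexpr unsigned kPrivateAddressSpace = 5;

// Registers a default lowering and then an override. The override captures
// the converter by reference, so the converter must outlive it.
void populateNVVMLikeConversions(ToyConverter &converter) {
  // Default conversion, registered first, so it is tried last.
  converter.addConversion([](MemRef type) -> std::optional<Converted> {
    return "ptr<" + std::to_string(type.memorySpace) + ">";
  });
  // Override, registered last, so it is tried first.
  converter.addConversion(
      [&converter](MemRef type) -> std::optional<Converted> {
        if (type.memorySpace != kPrivateAddressSpace)
          return std::nullopt; // not private: defer to the default
        // Drop the private memory space and re-enter the converter. The new
        // type no longer matches this callback, so recursion stops here.
        return converter.convertType(MemRef{0});
      });
}
```

A private memref (space 5) therefore lowers exactly like a default-space one, while other memory spaces are left to the default conversion untouched.
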
diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index bd83aaf9fef7..72806aa4b7fc 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -73,30 +73,19 @@ static LLVMType getPtrToElementType(T containerType,
       .getPointerTo();
 }
 
-// Convert the given type to the LLVM IR Dialect type.  The following
-// conversions are supported:
-//   - an Index type is converted into an LLVM integer type with pointer
-//     bitwidth (analogous to intptr_t in C);
-//   - an Integer type is converted into an LLVM integer type of the same width;
-//   - an F32 type is converted into an LLVM float type
-//   - a Buffer, Range or View is converted into an LLVM structure type
-//     containing the respective dynamic values.
-static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
+/// Convert the given range descriptor type to the LLVMIR dialect.
+/// Range descriptor contains the range bounds and the step as 64-bit integers.
+///
+/// struct {
+///   int64_t min;
+///   int64_t max;
+///   int64_t step;
+/// };
+static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
   auto *context = t.getContext();
-  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+  auto int64Ty = converter.convertType(IntegerType::get(64, context))
                      .cast<LLVM::LLVMType>();
-
-  // Range descriptor contains the range bounds and the step as 64-bit integers.
-  //
-  // struct {
-  //   int64_t min;
-  //   int64_t max;
-  //   int64_t step;
-  // };
-  if (t.isa<RangeType>())
-    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
-
-  return Type();
+  return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
 }
 
 namespace {
@@ -146,7 +135,7 @@ class RangeOpConversion : public ConvertToLLVMPattern {
                   ConversionPatternRewriter &rewriter) const override {
     auto rangeOp = cast<RangeOp>(op);
     auto rangeDescriptorTy =
-        convertLinalgType(rangeOp.getResult().getType(), typeConverter);
+        convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
 
     edsc::ScopedContext context(rewriter, op->getLoc());
 
@@ -418,12 +407,6 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
   return fnNameAttr;
 }
 
-Type LinalgTypeConverter::convertType(Type t) {
-  if (auto result = LLVMTypeConverter::convertType(t))
-    return result;
-  return convertLinalgType(t, *this);
-}
-
 namespace {
 
 // LinalgOpConversion<LinalgOp> creates a new call to the
@@ -555,10 +538,14 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
 
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(
-    LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     MLIRContext *ctx) {
   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                   TransposeOpConversion, YieldOpConversion>(ctx, converter);
+
+  // Populate the type conversions for the linalg types.
+  converter.addConversion(
+      [&](RangeType type) { return convertRangeType(type, converter); });
 }
 
 namespace {
@@ -572,7 +559,7 @@ void ConvertLinalgToLLVMPass::runOnModule() {
 
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
-  LinalgTypeConverter converter(&getContext());
+  LLVMTypeConverter converter(&getContext());
   populateAffineToStdConversionPatterns(patterns, &getContext());
   populateLoopToStdConversionPatterns(patterns, &getContext());
   populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,

diff  --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 2378b6a355cf..88ca986874e4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-
 #include "mlir/ADT/TypeSwitch.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -130,6 +129,19 @@ LLVMTypeConverter::LLVMTypeConverter(
       customizations(customs) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
   module = &llvmDialect->getLLVMModule();
+
+  // Register conversions for the standard types.
+  addConversion([&](FloatType type) { return convertFloatType(type); });
+  addConversion([&](FunctionType type) { return convertFunctionType(type); });
+  addConversion([&](IndexType type) { return convertIndexType(type); });
+  addConversion([&](IntegerType type) { return convertIntegerType(type); });
+  addConversion([&](MemRefType type) { return convertMemRefType(type); });
+  addConversion(
+      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
+  addConversion([&](VectorType type) { return convertVectorType(type); });
+
+  // LLVMType is legal, so add a pass-through conversion.
+  addConversion([](LLVM::LLVMType type) { return type; });
 }
 
 /// Get the LLVM context.
@@ -359,22 +371,6 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
   return vectorType;
 }
 
-// Dispatch based on the actual type.  Return null type on error.
-Type LLVMTypeConverter::convertStandardType(Type t) {
-  return TypeSwitch<Type, Type>(t)
-      .Case([&](FloatType type) { return convertFloatType(type); })
-      .Case([&](FunctionType type) { return convertFunctionType(type); })
-      .Case([&](IndexType type) { return convertIndexType(type); })
-      .Case([&](IntegerType type) { return convertIntegerType(type); })
-      .Case([&](MemRefType type) { return convertMemRefType(type); })
-      .Case([&](UnrankedMemRefType type) {
-        return convertUnrankedMemRefType(type);
-      })
-      .Case([&](VectorType type) { return convertVectorType(type); })
-      .Case([](LLVM::LLVMType type) { return type; })
-      .Default([](Type) { return Type(); });
-}
-
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter_,
@@ -2658,9 +2654,6 @@ void mlir::populateStdToLLVMBarePtrConversionPatterns(
   populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
 }
 
-// Convert types using the stored LLVM IR module.
-Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
-
 // Create an LLVM IR structure type if there is more than one result.
 Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   assert(!types.empty() && "expected non-empty list of type");

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 3d11e3c92cd9..773066148e20 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -98,41 +98,37 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
   return llvm::None;
 }
 
-static Type convertStdType(Type type) {
-  // If the type is already valid in SPIR-V, directly return.
-  if (spirv::SPIRVDialect::isValidType(type)) {
-    return type;
-  }
-
-  if (auto indexType = type.dyn_cast<IndexType>()) {
-    return SPIRVTypeConverter::getIndexType(type.getContext());
-  }
-
-  if (auto memRefType = type.dyn_cast<MemRefType>()) {
+SPIRVTypeConverter::SPIRVTypeConverter() {
+  addConversion([](Type type) -> Optional<Type> {
+    // If the type is already valid in SPIR-V, directly return.
+    return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
+  });
+  addConversion([](IndexType indexType) {
+    return SPIRVTypeConverter::getIndexType(indexType.getContext());
+  });
+  addConversion([this](MemRefType memRefType) -> Type {
     // TODO(ravishankarm): For now only support default memory space. The memory
     // space description is not set is stone within MLIR, i.e. it depends on the
     // context it is being used. To map this to SPIR-V storage classes, we
     // should rely on the ABI attributes, and not on the memory space. This is
     // still evolving, and needs to be revisited when there is more clarity.
-    if (memRefType.getMemorySpace()) {
+    if (memRefType.getMemorySpace())
       return Type();
-    }
 
-    auto elementType = convertStdType(memRefType.getElementType());
-    if (!elementType) {
+    auto elementType = convertType(memRefType.getElementType());
+    if (!elementType)
       return Type();
-    }
 
     auto elementSize = getTypeNumBytes(elementType);
-    if (!elementSize) {
+    if (!elementSize)
       return Type();
-    }
+
     // TODO(ravishankarm) : Handle dynamic shapes.
     if (memRefType.hasStaticShape()) {
       auto arraySize = getTypeNumBytes(memRefType);
-      if (!arraySize) {
+      if (!arraySize)
         return Type();
-      }
+
       auto arrayType = spirv::ArrayType::get(
           elementType, arraySize.getValue() / elementSize.getValue(),
           elementSize.getValue());
@@ -142,34 +138,31 @@ static Type convertStdType(Type type) {
       return spirv::PointerType::get(structType,
                                      spirv::StorageClass::StorageBuffer);
     }
-  }
-
-  if (auto tensorType = type.dyn_cast<TensorType>()) {
+    return Type();
+  });
+  addConversion([this](TensorType tensorType) -> Type {
     // TODO(ravishankarm) : Handle dynamic shapes.
-    if (!tensorType.hasStaticShape()) {
+    if (!tensorType.hasStaticShape())
       return Type();
-    }
-    auto elementType = convertStdType(tensorType.getElementType());
-    if (!elementType) {
+
+    auto elementType = convertType(tensorType.getElementType());
+    if (!elementType)
       return Type();
-    }
+
     auto elementSize = getTypeNumBytes(elementType);
-    if (!elementSize) {
+    if (!elementSize)
       return Type();
-    }
+
     auto tensorSize = getTypeNumBytes(tensorType);
-    if (!tensorSize) {
+    if (!tensorSize)
       return Type();
-    }
+
     return spirv::ArrayType::get(elementType,
                                  tensorSize.getValue() / elementSize.getValue(),
                                  elementSize.getValue());
-  }
-  return Type();
+  });
 }
 
-Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }
-
 //===----------------------------------------------------------------------===//
 // FuncOp Conversion Patterns
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 40dd16db02cf..8e1a9cc942bd 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1646,13 +1646,26 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 /// This hooks allows for converting a type.
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) {
-  if (auto newT = convertType(t)) {
-    results.push_back(newT);
-    return success();
-  }
+  // Walk the added converters in reverse order to apply the most recently
+  // registered first.
+  for (ConversionCallbackFn &converter : llvm::reverse(conversions))
+    if (Optional<LogicalResult> result = converter(t, results))
+      return *result;
   return failure();
 }
 
+/// This hook simplifies defining 1-1 type conversions. This function returns
+/// the type to convert to on success, and a null type on failure.
+Type TypeConverter::convertType(Type t) {
+  // Use the multi-type result version to convert the type.
+  SmallVector<Type, 1> results;
+  if (failed(convertType(t, results)))
+    return nullptr;
+
+  // Check to ensure that only one type was produced.
+  return results.size() == 1 ? results.front() : nullptr;
+}
+
 /// Convert the given set of types, filling 'results' as necessary. This
 /// returns failure if the conversion of any of the types fails, success
 /// otherwise.

diff  --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index d975d6719170..534e86f85a63 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -305,8 +305,9 @@ struct TestNonRootReplacement : public RewritePattern {
 namespace {
 struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
+  TestTypeConverter() { addConversion(convertType); }
 
-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
+  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
     // Drop I16 types.
     if (t.isInteger(16))
       return success();


        

