[Mlir-commits] [mlir] 0d7ff22 - [mlir] Refactor TypeConverter to add conversions without inheritance
River Riddle
llvmlistbot at llvm.org
Tue Feb 18 16:18:48 PST 2020
Author: River Riddle
Date: 2020-02-18T16:17:48-08:00
New Revision: 0d7ff220ed0eeffda3839546ac1f30e98f60128e
URL: https://github.com/llvm/llvm-project/commit/0d7ff220ed0eeffda3839546ac1f30e98f60128e
DIFF: https://github.com/llvm/llvm-project/commit/0d7ff220ed0eeffda3839546ac1f30e98f60128e.diff
LOG: [mlir] Refactor TypeConverter to add conversions without inheritance
Summary:
This revision refactors the TypeConverter class so that type conversions are no longer added through inheritance. Instead it moves to a registration-based system, where conversion callbacks are added to the converter with `addConversion`. This method takes a conversion callback, which must be convertible to any of the following forms (where `T` is a class derived from `Type`):
* Optional<Type> (T type)
- This form represents a 1-1 type conversion. It should return nullptr
or `llvm::None` to signify failure. If `llvm::None` is returned, the
converter is allowed to try another conversion function to perform
the conversion.
* Optional<LogicalResult>(T type, SmallVectorImpl<Type> &results)
- This form represents a 1-N type conversion. It should return
`failure` or `llvm::None` to signify a failed conversion. If the new
set of types is empty, the type is removed and any usages of the
existing value are expected to be removed during conversion. If
`llvm::None` is returned, the converter is allowed to try another
conversion function to perform the conversion.
When attempting to convert a type, the TypeConverter walks each of the registered converters starting with the one registered most recently.
Differential Revision: https://reviews.llvm.org/D74584
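To make the lookup order concrete, here is a small standalone C++17 model of the registration scheme described above. It does not use MLIR: `ToyType`, `ToyTypeConverter`, and the three example rules are hypothetical stand-ins, and every callback is assumed to already be in the normalized 1-N shape, where `std::nullopt` means "not handled here, try the next callback".

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Toy stand-in for mlir::Type (hypothetical): a kind tag plus a bit width.
// A default-constructed ToyType plays the role of the null type.
struct ToyType {
  enum Kind { Null, Int, Index, Float } kind = Null;
  unsigned width = 0;
  explicit operator bool() const { return kind != Null; }
};

// Normalized callback shape:
//   std::nullopt -> not handled here, try the next (older) callback
//   false        -> conversion failed, stop searching
//   true         -> conversion succeeded, `results` holds the new types
using ConversionFn =
    std::function<std::optional<bool>(ToyType, std::vector<ToyType> &)>;

class ToyTypeConverter {
public:
  void addConversion(ConversionFn fn) { conversions.push_back(std::move(fn)); }

  // Consult the most recently registered callback first.
  bool convertType(ToyType t, std::vector<ToyType> &results) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it)
      if (std::optional<bool> result = (*it)(t, results))
        return *result;
    return false; // No callback claimed the type.
  }

  // 1-1 helper: returns the null type unless exactly one type was produced.
  ToyType convertType(ToyType t) const {
    std::vector<ToyType> results;
    if (!convertType(t, results) || results.size() != 1)
      return ToyType();
    return results.front();
  }

private:
  std::vector<ConversionFn> conversions;
};

// Hypothetical rules: integers pass through, index lowers to a 64-bit
// integer, and a later registration widens 16-bit integers to 32 bits.
ToyTypeConverter makeConverter() {
  ToyTypeConverter converter;
  converter.addConversion(
      [](ToyType t, std::vector<ToyType> &results) -> std::optional<bool> {
        if (t.kind != ToyType::Int)
          return std::nullopt;
        results.push_back(t);
        return true;
      });
  converter.addConversion(
      [](ToyType t, std::vector<ToyType> &results) -> std::optional<bool> {
        if (t.kind != ToyType::Index)
          return std::nullopt;
        results.push_back({ToyType::Int, 64});
        return true;
      });
  // Registered last, so consulted first; defers on anything but i16.
  converter.addConversion(
      [](ToyType t, std::vector<ToyType> &results) -> std::optional<bool> {
        if (t.kind != ToyType::Int || t.width != 16)
          return std::nullopt;
        results.push_back({ToyType::Int, 32});
        return true;
      });
  return converter;
}
```

With this setup an `i16` is claimed by the newest callback and comes out 32 bits wide, an `i8` falls through to the oldest pass-through rule, and a float matches nothing, so its conversion fails.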
Added:
Modified:
mlir/docs/DialectConversion.md
mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/include/mlir/Support/STLExtras.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index d0643947ed20..8af0e4fb0b25 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -196,15 +196,26 @@ convert types. Several of these hooks are detailed below:
```c++
class TypeConverter {
public:
- /// This hook allows for converting a type. This function should return
- /// failure if no valid conversion exists, success otherwise. If the new set
- /// of types is empty, the type is removed and any usages of the existing
- /// value are expected to be removed during conversion.
- virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
-
- /// This hook simplifies defining 1-1 type conversions. This function returns
- /// the type to convert to on success, and a null type on failure.
- virtual Type convertType(Type t);
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// * Optional<Type>(T)
+ /// - This form represents a 1-1 type conversion. It should return nullptr
+ /// or `llvm::None` to signify failure. If `llvm::None` is returned, the
+ /// converter is allowed to try another conversion function to perform
+ /// the conversion.
+ /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// - This form represents a 1-N type conversion. It should return
+ /// `failure` or `llvm::None` to signify a failed conversion. If the new
+ /// set of types is empty, the type is removed and any usages of the
+ /// existing value are expected to be removed during conversion. If
+ /// `llvm::None` is returned, the converter is allowed to try another
+ /// conversion function to perform the conversion.
+ ///
+ /// When attempting to convert a type, e.g. via `convertType`, the
+ /// `TypeConverter` will invoke each of the converters starting with the one
+ /// most recently registered.
+ template <typename ConversionFnT>
+ void addConversion(ConversionFnT &&callback);
/// This hook allows for materializing a conversion from a set of types into
/// one result type by generating a cast operation of some kind. The generated
diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
index e2e5bd2880a6..4124f3f0e3b2 100644
--- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
+++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
@@ -16,14 +16,8 @@ class MLIRContext;
class ModuleOp;
template <typename T> class OpPassBase;
-class LinalgTypeConverter : public LLVMTypeConverter {
-public:
- using LLVMTypeConverter::LLVMTypeConverter;
- Type convertType(Type t) override;
-};
-
/// Populate the given list with patterns that convert from Linalg to LLVM.
-void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
+void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx);
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 81e3ea8f1735..db3a948f000f 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -78,10 +78,6 @@ class LLVMTypeConverter : public TypeConverter {
LLVMTypeConverter(MLIRContext *ctx,
const LLVMTypeConverterCustomization &custom);
- /// Convert types to LLVM IR. This calls `convertAdditionalType` to convert
- /// non-standard or non-builtin types.
- Type convertType(Type t) override;
-
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
@@ -129,8 +125,6 @@ class LLVMTypeConverter : public TypeConverter {
LLVM::LLVMDialect *llvmDialect;
private:
- Type convertStandardType(Type type);
-
// Convert a function type. The arguments and results are converted one by
// one. Additionally, if the function returns more than one value, pack the
// results into an LLVM IR structure type so that the converted function type
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 2a70173ec37c..85b42eeea291 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -27,10 +27,7 @@ namespace mlir {
/// pointers to structs.
class SPIRVTypeConverter : public TypeConverter {
public:
- using TypeConverter::TypeConverter;
-
- /// Converts the given standard `type` to SPIR-V correspondence.
- Type convertType(Type type) override;
+ SPIRVTypeConverter();
/// Gets the SPIR-V correspondence for the standard index type.
static Type getIndexType(MLIRContext *context);
diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 0e6b95ac5bdc..14336aad6a25 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -376,6 +376,10 @@ struct FunctionTraits<ReturnType (*)(Args...), false> {
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
+/// Overload for non-class function type references.
+template <typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType (&)(Args...), false>
+ : public FunctionTraits<ReturnType (*)(Args...)> {};
} // end namespace mlir
// Allow tuples to be usable as DenseMap keys.
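The `FunctionTraits` overload for function references added above appears to be what lets `addConversion` accept a plain function name, as `TestTypeConverter` does later in this diff with `addConversion(convertType)`: when a function lvalue is passed to a forwarding reference `FnT &&`, `FnT` is deduced as a function-reference type. The sketch below is a simplified standalone re-creation (not the exact MLIR header) showing how the first parameter type is recovered from a lambda and from a function reference; `IntegerTy`, `FloatTy`, `convertInteger`, and `firstArgIs` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>

// Class types (lambdas, functors) are inspected via their call operator.
template <typename T, bool isClass = std::is_class<T>::value>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};

// Const call operator, as generated for non-mutable lambdas.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
  using result_t = ReturnType;
  template <std::size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};

// Plain function pointers.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (*)(Args...), false> {
  using result_t = ReturnType;
  template <std::size_t i>
  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};

// Function references: forward to the pointer specialization.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (&)(Args...), false>
    : public FunctionTraits<ReturnType (*)(Args...)> {};

// Hypothetical stand-ins for callback argument types.
struct IntegerTy {};
struct FloatTy {};

bool convertInteger(IntegerTy, int) { return true; }

// Mirrors how a registration API could deduce `T` from the callback's first
// parameter. Lambdas are expected to be passed as temporaries, so FnT is the
// closure type itself rather than a reference to it.
template <typename Expected, typename FnT>
bool firstArgIs(FnT &&) {
  using Arg0 = typename FunctionTraits<FnT>::template arg_t<0>;
  return std::is_same<Arg0, Expected>::value;
}
```

Without the reference specialization, `firstArgIs<IntegerTy>(convertInteger)` would fail to compile, because `FnT` is `bool (&)(IntegerTy, int)` and none of the other specializations match it.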
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index cd148f2fd2ea..664005402ead 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -91,15 +91,37 @@ class TypeConverter {
SmallVector<Type, 4> argTypes;
};
- /// This hook allows for converting a type. This function should return
- /// failure if no valid conversion exists, success otherwise. If the new set
- /// of types is empty, the type is removed and any usages of the existing
- /// value are expected to be removed during conversion.
- virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// * Optional<Type>(T)
+ /// - This form represents a 1-1 type conversion. It should return nullptr
+ /// or `llvm::None` to signify failure. If `llvm::None` is returned, the
+ /// converter is allowed to try another conversion function to perform
+ /// the conversion.
+ /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// - This form represents a 1-N type conversion. It should return
+ /// `failure` or `llvm::None` to signify a failed conversion. If the new
+ /// set of types is empty, the type is removed and any usages of the
+ /// existing value are expected to be removed during conversion. If
+ /// `llvm::None` is returned, the converter is allowed to try another
+ /// conversion function to perform the conversion.
+ /// Note: When attempting to convert a type, e.g. via 'convertType', the
+ /// most recently added conversions will be invoked first.
+ template <typename FnT,
+ typename T = typename FunctionTraits<FnT>::template arg_t<0>>
+ void addConversion(FnT &&callback) {
+ registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
+ }
+
+ /// Convert the given type. This function should return failure if no valid
+ /// conversion exists, success otherwise. If the new set of types is empty,
+ /// the type is removed and any usages of the existing value are expected to
+ /// be removed during conversion.
+ LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
- virtual Type convertType(Type t) { return t; }
+ Type convertType(Type t);
/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
@@ -138,6 +160,50 @@ class TypeConverter {
Location loc) {
llvm_unreachable("expected 'materializeConversion' to be overridden");
}
+
+private:
+ /// The signature of the callback used to convert a type. If the new set of
+ /// types is empty, the type is removed and any usages of the existing value
+ /// are expected to be removed during conversion.
+ using ConversionCallbackFn =
+ std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
+
+ /// Generate a wrapper for the given callback. This allows for accepting
+ /// different callback forms, that all compose into a single version.
+ /// With callback of form: `Optional<Type>(T)`
+ template <typename T, typename FnT>
+ std::enable_if_t<is_invocable<FnT, T>::value, ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ return wrapCallback<T>([=](T type, SmallVectorImpl<Type> &results) {
+ if (Optional<Type> resultOpt = callback(type)) {
+ bool wasSuccess = static_cast<bool>(resultOpt.getValue());
+ if (wasSuccess)
+ results.push_back(resultOpt.getValue());
+ return Optional<LogicalResult>(success(wasSuccess));
+ }
+ return Optional<LogicalResult>();
+ });
+ }
+ /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
+ template <typename T, typename FnT>
+ std::enable_if_t<!is_invocable<FnT, T>::value, ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ return [=](Type type,
+ SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
+ T derivedType = type.dyn_cast<T>();
+ if (!derivedType)
+ return llvm::None;
+ return callback(derivedType, results);
+ };
+ }
+
+ /// Register a type conversion.
+ void registerConversion(ConversionCallbackFn callback) {
+ conversions.emplace_back(std::move(callback));
+ }
+
+ /// The set of registered conversion functions.
+ SmallVector<ConversionCallbackFn, 4> conversions;
};
//===----------------------------------------------------------------------===//
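The two `wrapCallback` overloads above select a form at compile time. If the callback is invocable with a single `T`, it is treated as a 1-1 conversion and adapted into the 1-N shape. Otherwise it is treated as 1-N and only wrapped with a `dyn_cast` filter. The following standalone C++17 sketch reproduces that dispatch using `std::is_invocable_v` and `std::optional` in place of MLIR's own helpers; `ToyType`, `ToyIntType`, and its `dynCast` function are hypothetical substitutes for MLIR types and `dyn_cast`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <type_traits>
#include <vector>

// Toy base type (hypothetical): Kind::Null plays the role of the null type.
struct ToyType {
  enum Kind { Null, Int, Float } kind = Null;
  unsigned width = 0;
  explicit operator bool() const { return kind != Null; }
};

// Toy "derived" type with a dyn_cast-like helper.
struct ToyIntType {
  unsigned width;
  static std::optional<ToyIntType> dynCast(ToyType t) {
    if (t.kind != ToyType::Int)
      return std::nullopt;
    return ToyIntType{t.width};
  }
};

using ConversionCallbackFn =
    std::function<std::optional<bool>(ToyType, std::vector<ToyType> &)>;

// 1-N form, callable as (T, results): filter on T, then forward.
template <typename T, typename FnT>
std::enable_if_t<!std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT callback) {
  return [=](ToyType type,
             std::vector<ToyType> &results) -> std::optional<bool> {
    std::optional<T> derived = T::dynCast(type);
    if (!derived)
      return std::nullopt; // Not a T: let another callback try.
    return callback(*derived, results);
  };
}

// 1-1 form, callable as (T): adapt into the 1-N form, then wrap that.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT callback) {
  return wrapCallback<T>(
      [=](T type, std::vector<ToyType> &results) -> std::optional<bool> {
        if (std::optional<ToyType> result = callback(type)) {
          // An engaged but null type signals a hard failure.
          bool wasSuccess = static_cast<bool>(*result);
          if (wasSuccess)
            results.push_back(*result);
          return wasSuccess;
        }
        return std::nullopt;
      });
}
```

As in the patch, a 1-1 callback that returns an engaged but null type is reported as a hard failure (`false`), while returning `std::nullopt` lets the converter move on to older callbacks.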
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2164c038c124..df0a0535e7ff 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -28,25 +28,6 @@ using namespace mlir;
namespace {
-/// Derived type converter for GPU to NVVM lowering. The GPU dialect uses memory
-/// space 5 for private memory attributions, but NVVM represents private
-/// memory allocations as local `alloca`s in the default address space. This
-/// converter drops the private memory space to support the use case above.
-class NVVMTypeConverter : public LLVMTypeConverter {
-public:
- using LLVMTypeConverter::LLVMTypeConverter;
-
- Type convertType(Type type) override {
- auto memref = type.dyn_cast<MemRefType>();
- if (memref &&
- memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
- type = MemRefType::Builder(memref).setMemorySpace(0);
- }
-
- return LLVMTypeConverter::convertType(type);
- }
-};
-
/// Converts all_reduce op to LLVM/NVVM ops.
struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
using AccumulatorFactory =
@@ -683,8 +664,19 @@ class LowerGpuOpsToNVVMOpsPass
public:
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
+
+ /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
+ /// space 5 for private memory attributions, but NVVM represents private
+ /// memory allocations as local `alloca`s in the default address space. This
+ /// converter drops the private memory space to support the use case above.
+ LLVMTypeConverter converter(m.getContext());
+ converter.addConversion([&](MemRefType type) -> Optional<Type> {
+ if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
+ return llvm::None;
+ return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
+ });
+
OwningRewritePatternList patterns;
- NVVMTypeConverter converter(m.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToNVVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
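The NVVM lowering above replaces a subclass override with a callback that handles only private-memory `MemRefType`s, rewrites the memory space to 0, and re-enters `convertType`. A minimal standalone model of that "normalize and delegate" pattern follows; `ToyMemRef`, `ToyConverter`, and `populateToyConversions` are hypothetical, and the older callback simply marks a type as lowered while keeping its memory space.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Toy memref-like type (hypothetical). `lowered` marks a converted type,
// standing in for the LLVM dialect result.
struct ToyMemRef {
  unsigned memorySpace = 0;
  bool lowered = false;
};

class ToyConverter {
public:
  using Fn = std::function<std::optional<ToyMemRef>(ToyMemRef)>;
  void addConversion(Fn fn) { conversions.push_back(std::move(fn)); }

  // Newest callback first; std::nullopt means "defer to an older callback".
  std::optional<ToyMemRef> convertType(ToyMemRef t) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it)
      if (std::optional<ToyMemRef> r = (*it)(t))
        return r;
    return std::nullopt;
  }

private:
  std::vector<Fn> conversions;
};

// Private address space used by the GPU dialect, per the comment above.
constexpr unsigned kPrivateAddressSpace = 5;

void populateToyConversions(ToyConverter &converter) {
  // Older, generic lowering: keeps the memory space.
  converter.addConversion([](ToyMemRef t) -> std::optional<ToyMemRef> {
    return ToyMemRef{t.memorySpace, true};
  });
  // Newer, special case: rewrite private memory to space 0 and delegate.
  // Capturing `converter` by reference requires that it outlive this
  // callback and not be copied.
  converter.addConversion(
      [&converter](ToyMemRef t) -> std::optional<ToyMemRef> {
        if (t.memorySpace != kPrivateAddressSpace)
          return std::nullopt;
        return converter.convertType(ToyMemRef{0, t.lowered});
      });
}
```

The recursion terminates because the rewritten type has memory space 0, which the newest callback declines, so the call falls through to the older lowering. In the pass above the converter is a local in `runOnOperation`, so the by-reference capture stays valid for the duration of the conversion.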
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index bd83aaf9fef7..72806aa4b7fc 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -73,30 +73,19 @@ static LLVMType getPtrToElementType(T containerType,
.getPointerTo();
}
-// Convert the given type to the LLVM IR Dialect type. The following
-// conversions are supported:
-// - an Index type is converted into an LLVM integer type with pointer
-// bitwidth (analogous to intptr_t in C);
-// - an Integer type is converted into an LLVM integer type of the same width;
-// - an F32 type is converted into an LLVM float type
-// - a Buffer, Range or View is converted into an LLVM structure type
-// containing the respective dynamic values.
-static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
+/// Convert the given range descriptor type to the LLVMIR dialect.
+/// Range descriptor contains the range bounds and the step as 64-bit integers.
+///
+/// struct {
+/// int64_t min;
+/// int64_t max;
+/// int64_t step;
+/// };
+static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext();
- auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+ auto int64Ty = converter.convertType(IntegerType::get(64, context))
.cast<LLVM::LLVMType>();
-
- // Range descriptor contains the range bounds and the step as 64-bit integers.
- //
- // struct {
- // int64_t min;
- // int64_t max;
- // int64_t step;
- // };
- if (t.isa<RangeType>())
- return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
-
- return Type();
+ return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}
namespace {
@@ -146,7 +135,7 @@ class RangeOpConversion : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
- convertLinalgType(rangeOp.getResult().getType(), typeConverter);
+ convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -418,12 +407,6 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
-Type LinalgTypeConverter::convertType(Type t) {
- if (auto result = LLVMTypeConverter::convertType(t))
- return result;
- return convertLinalgType(t, *this);
-}
-
namespace {
// LinalgOpConversion<LinalgOp> creates a new call to the
@@ -555,10 +538,14 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
- LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
TransposeOpConversion, YieldOpConversion>(ctx, converter);
+
+ // Populate the type conversions for the linalg types.
+ converter.addConversion(
+ [&](RangeType type) { return convertRangeType(type, converter); });
}
namespace {
@@ -572,7 +559,7 @@ void ConvertLinalgToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
- LinalgTypeConverter converter(&getContext());
+ LLVMTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 2378b6a355cf..88ca986874e4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -130,6 +129,19 @@ LLVMTypeConverter::LLVMTypeConverter(
customizations(customs) {
assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
+
+ // Register conversions for the standard types.
+ addConversion([&](FloatType type) { return convertFloatType(type); });
+ addConversion([&](FunctionType type) { return convertFunctionType(type); });
+ addConversion([&](IndexType type) { return convertIndexType(type); });
+ addConversion([&](IntegerType type) { return convertIntegerType(type); });
+ addConversion([&](MemRefType type) { return convertMemRefType(type); });
+ addConversion(
+ [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
+ addConversion([&](VectorType type) { return convertVectorType(type); });
+
+ // LLVMType is legal, so add a pass-through conversion.
+ addConversion([](LLVM::LLVMType type) { return type; });
}
/// Get the LLVM context.
@@ -359,22 +371,6 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
return vectorType;
}
-// Dispatch based on the actual type. Return null type on error.
-Type LLVMTypeConverter::convertStandardType(Type t) {
- return TypeSwitch<Type, Type>(t)
- .Case([&](FloatType type) { return convertFloatType(type); })
- .Case([&](FunctionType type) { return convertFunctionType(type); })
- .Case([&](IndexType type) { return convertIndexType(type); })
- .Case([&](IntegerType type) { return convertIntegerType(type); })
- .Case([&](MemRefType type) { return convertMemRefType(type); })
- .Case([&](UnrankedMemRefType type) {
- return convertUnrankedMemRefType(type);
- })
- .Case([&](VectorType type) { return convertVectorType(type); })
- .Case([](LLVM::LLVMType type) { return type; })
- .Default([](Type) { return Type(); });
-}
-
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter_,
@@ -2658,9 +2654,6 @@ void mlir::populateStdToLLVMBarePtrConversionPatterns(
populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
}
-// Convert types using the stored LLVM IR module.
-Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
-
// Create an LLVM IR structure type if there is more than one result.
Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of type");
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 3d11e3c92cd9..773066148e20 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -98,41 +98,37 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}
-static Type convertStdType(Type type) {
- // If the type is already valid in SPIR-V, directly return.
- if (spirv::SPIRVDialect::isValidType(type)) {
- return type;
- }
-
- if (auto indexType = type.dyn_cast<IndexType>()) {
- return SPIRVTypeConverter::getIndexType(type.getContext());
- }
-
- if (auto memRefType = type.dyn_cast<MemRefType>()) {
+SPIRVTypeConverter::SPIRVTypeConverter() {
+ addConversion([](Type type) -> Optional<Type> {
+ // If the type is already valid in SPIR-V, directly return.
+ return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
+ });
+ addConversion([](IndexType indexType) {
+ return SPIRVTypeConverter::getIndexType(indexType.getContext());
+ });
+ addConversion([this](MemRefType memRefType) -> Type {
// TODO(ravishankarm): For now only support default memory space. The memory
// space description is not set in stone within MLIR, i.e. it depends on the
// context it is being used. To map this to SPIR-V storage classes, we
// should rely on the ABI attributes, and not on the memory space. This is
// still evolving, and needs to be revisited when there is more clarity.
- if (memRefType.getMemorySpace()) {
+ if (memRefType.getMemorySpace())
return Type();
- }
- auto elementType = convertStdType(memRefType.getElementType());
- if (!elementType) {
+ auto elementType = convertType(memRefType.getElementType());
+ if (!elementType)
return Type();
- }
auto elementSize = getTypeNumBytes(elementType);
- if (!elementSize) {
+ if (!elementSize)
return Type();
- }
+
// TODO(ravishankarm) : Handle dynamic shapes.
if (memRefType.hasStaticShape()) {
auto arraySize = getTypeNumBytes(memRefType);
- if (!arraySize) {
+ if (!arraySize)
return Type();
- }
+
auto arrayType = spirv::ArrayType::get(
elementType, arraySize.getValue() / elementSize.getValue(),
elementSize.getValue());
@@ -142,34 +138,31 @@ static Type convertStdType(Type type) {
return spirv::PointerType::get(structType,
spirv::StorageClass::StorageBuffer);
}
- }
-
- if (auto tensorType = type.dyn_cast<TensorType>()) {
+ return Type();
+ });
+ addConversion([this](TensorType tensorType) -> Type {
// TODO(ravishankarm) : Handle dynamic shapes.
- if (!tensorType.hasStaticShape()) {
+ if (!tensorType.hasStaticShape())
return Type();
- }
- auto elementType = convertStdType(tensorType.getElementType());
- if (!elementType) {
+
+ auto elementType = convertType(tensorType.getElementType());
+ if (!elementType)
return Type();
- }
+
auto elementSize = getTypeNumBytes(elementType);
- if (!elementSize) {
+ if (!elementSize)
return Type();
- }
+
auto tensorSize = getTypeNumBytes(tensorType);
- if (!tensorSize) {
+ if (!tensorSize)
return Type();
- }
+
return spirv::ArrayType::get(elementType,
tensorSize.getValue() / elementSize.getValue(),
elementSize.getValue());
- }
- return Type();
+ });
}
-Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }
-
//===----------------------------------------------------------------------===//
// FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 40dd16db02cf..8e1a9cc942bd 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1646,13 +1646,26 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
/// This hook allows for converting a type.
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) {
- if (auto newT = convertType(t)) {
- results.push_back(newT);
- return success();
- }
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (ConversionCallbackFn &converter : llvm::reverse(conversions))
+ if (Optional<LogicalResult> result = converter(t, results))
+ return *result;
return failure();
}
+/// This hook simplifies defining 1-1 type conversions. This function returns
+/// the type to convert to on success, and a null type on failure.
+Type TypeConverter::convertType(Type t) {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(t, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// otherwise.
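The new `TypeConverter::convertType(Type)` above is now a thin view over the 1-N version: it succeeds only when exactly one type is produced. The sketch below models that in standalone C++17 with hypothetical toy rules (drop `i16`, split a 128-bit integer into two 64-bit halves, pass everything else through) to show where the 1-1 and 1-N entry points disagree; `IntTy` and `ToyConverter` are stand-ins, not MLIR classes.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Toy integer type (hypothetical); width 0 stands in for the null type.
struct IntTy {
  unsigned width = 0;
  explicit operator bool() const { return width != 0; }
};

using ConversionFn =
    std::function<std::optional<bool>(IntTy, std::vector<IntTy> &)>;

class ToyConverter {
public:
  void addConversion(ConversionFn fn) { conversions.push_back(std::move(fn)); }

  // 1-N entry point: walk callbacks newest first.
  bool convertType(IntTy t, std::vector<IntTy> &results) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it)
      if (std::optional<bool> r = (*it)(t, results))
        return *r;
    return false;
  }

  // 1-1 entry point: null unless exactly one type was produced.
  IntTy convertType(IntTy t) const {
    std::vector<IntTy> results;
    if (!convertType(t, results))
      return IntTy();
    return results.size() == 1 ? results.front() : IntTy();
  }

private:
  std::vector<ConversionFn> conversions;
};

ToyConverter makeConverter() {
  ToyConverter converter;
  converter.addConversion(
      [](IntTy t, std::vector<IntTy> &results) -> std::optional<bool> {
        if (t.width == 16)
          return true; // Drop i16: success with no replacement types.
        if (t.width == 128) {
          // Split i128 into two i64 halves.
          results.push_back(IntTy{64});
          results.push_back(IntTy{64});
          return true;
        }
        results.push_back(t); // Everything else maps to itself.
        return true;
      });
  return converter;
}
```

Dropping `i16` therefore succeeds through the 1-N entry point, mirroring the `TestTypeConverter` rule in this patch, but yields a null type through the 1-1 entry point; the same holds for any conversion that produces more than one type.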
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index d975d6719170..534e86f85a63 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -305,8 +305,9 @@ struct TestNonRootReplacement : public RewritePattern {
namespace {
struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
+ TestTypeConverter() { addConversion(convertType); }
- LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
+ static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
// Drop I16 types.
if (t.isInteger(16))
return success();