[Mlir-commits] [mlir] 5c5dafc - [mlir] support materialization for 1-1 type conversions
Alex Zinenko
llvmlistbot at llvm.org
Tue Jun 2 04:48:43 PDT 2020
Author: Alex Zinenko
Date: 2020-06-02T13:48:33+02:00
New Revision: 5c5dafc534ac80aad978f4092ff842457aab6d07
URL: https://github.com/llvm/llvm-project/commit/5c5dafc534ac80aad978f4092ff842457aab6d07
DIFF: https://github.com/llvm/llvm-project/commit/5c5dafc534ac80aad978f4092ff842457aab6d07.diff
LOG: [mlir] support materialization for 1-1 type conversions
Dialect conversion infrastructure supports 1->N type conversions by requiring
individual conversions to provide facilities to generate operations
that pack N values back into a single value of the original type when N > 1. This
functionality can also be used to materialize explicit "cast"-like operations,
but it did not support 1->1 type conversions until now. Modify TypeConverter to
support materialization of cast operations for 1-1 conversions.
This also makes materialization specification more extensible following the
same pattern as type conversions. Instead of overloading a virtual function,
users or subclasses of TypeConverter can now register type-specific
materialization callbacks that will be called in order for the given type.
Differential Revision: https://reviews.llvm.org/D79729
Added:
Modified:
mlir/docs/DialectConversion.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 1f9ec14b6a96..7995099636e9 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -217,16 +217,20 @@ class TypeConverter {
template <typename ConversionFnT>
void addConversion(ConversionFnT &&callback);
- /// This hook allows for materializing a conversion from a set of types into
- /// one result type by generating a cast operation of some kind. The generated
- /// operation should produce one result, of 'resultType', with the provided
- /// 'inputs' as operands. This hook must be overridden when a type conversion
+ /// Register a materialization function, which must be convertible to the
+ /// following form:
+ /// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
+ /// where `T` is any subclass of `Type`. This function is responsible for
+ /// creating an operation, using the PatternRewriter and Location provided,
+ /// that "casts" a range of values into a single value of the given type `T`.
+ /// It must return a Value of the converted type on success, `llvm::None`
+ /// if it failed but other materializations can be attempted, and `nullptr` on
+ /// unrecoverable failure. It will only be called for (sub)types of `T`.
+ /// Materialization functions must be provided when a type conversion
/// results in more than one type, or if a type conversion may persist after
/// the conversion has finished.
- virtual Operation *materializeConversion(PatternRewriter &rewriter,
- Type resultType,
- ArrayRef<Value> inputs,
- Location loc);
+ template <typename FnT>
+ void addMaterialization(FnT &&callback);
};
```
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index c241de6ff6fe..a5c14139668b 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -122,11 +122,6 @@ class LLVMTypeConverter : public TypeConverter {
/// pointers to memref descriptors for arguments.
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
- /// Creates descriptor structs from individual values constituting them.
- Operation *materializeConversion(PatternRewriter &rewriter, Type type,
- ArrayRef<Value> values,
- Location loc) override;
-
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
LLVM::LLVMType getIndexType();
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 31b5e04c9dbd..029c38f6a1c7 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -113,6 +113,25 @@ class TypeConverter {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
+ /// Register a materialization function, which must be convertible to the
+ /// following form:
+ /// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
+ /// where `T` is any subclass of `Type`. This function is responsible for
+ /// creating an operation, using the PatternRewriter and Location provided,
+ /// that "casts" a range of values into a single value of the given type `T`.
+ /// It must return a Value of the converted type on success, `llvm::None`
+ /// if it failed but other materializations can be attempted, and `nullptr` on
+ /// unrecoverable failure. It will only be called for (sub)types of `T`.
+ /// Materialization functions must be provided when a type conversion
+ /// results in more than one type, or if a type conversion may persist after
+ /// the conversion has finished.
+ template <typename FnT,
+ typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+ void addMaterialization(FnT &&callback) {
+ registerMaterialization(
+ wrapMaterialization<T>(std::forward<FnT>(callback)));
+ }
+
/// Convert the given type. This function should return failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
@@ -148,18 +167,10 @@ class TypeConverter {
/// valid conversion for the signature on success, None otherwise.
Optional<SignatureConversion> convertBlockSignature(Block *block);
- /// This hook allows for materializing a conversion from a set of types into
- /// one result type by generating a cast operation of some kind. The generated
- /// operation should produce one result, of 'resultType', with the provided
- /// 'inputs' as operands. This hook must be overridden when a type conversion
- /// results in more than one type, or if a type conversion may persist after
- /// the conversion has finished.
- virtual Operation *materializeConversion(PatternRewriter &rewriter,
- Type resultType,
- ArrayRef<Value> inputs,
- Location loc) {
- llvm_unreachable("expected 'materializeConversion' to be overridden");
- }
+ /// Materialize a conversion from a set of types into one result type by
+ /// generating a cast operation of some kind.
+ Value materializeConversion(PatternRewriter &rewriter, Location loc,
+ Type resultType, ValueRange inputs);
private:
/// The signature of the callback used to convert a type. If the new set of
@@ -168,6 +179,9 @@ class TypeConverter {
using ConversionCallbackFn =
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
+ using MaterializationCallbackFn = std::function<Optional<Value>(
+ PatternRewriter &, Type, ValueRange, Location)>;
+
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `Optional<Type>(T)`
@@ -204,8 +218,30 @@ class TypeConverter {
conversions.emplace_back(std::move(callback));
}
+ /// Generate a wrapper for the given materialization callback. The callback
+ /// may take any subclass of `Type` and the wrapper will check for the target
+ /// type to be of the expected class before calling the callback.
+ template <typename T, typename FnT>
+ MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
+ return [callback = std::forward<FnT>(callback)](
+ PatternRewriter &rewriter, Type resultType, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+ if (T derivedType = resultType.dyn_cast<T>())
+ return callback(rewriter, derivedType, inputs, loc);
+ return llvm::None;
+ };
+ }
+
+ /// Register a materialization.
+ void registerMaterialization(MaterializationCallbackFn &&callback) {
+ materializations.emplace_back(std::move(callback));
+ }
+
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
+
+ /// The list of registered materialization functions.
+ SmallVector<MaterializationCallbackFn, 2> materializations;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 8cc2315ddd15..4294e0024e79 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -150,6 +150,24 @@ LLVMTypeConverter::LLVMTypeConverter(
// LLVMType is legal, so add a pass-through conversion.
addConversion([](LLVM::LLVMType type) { return type; });
+
+ // Materialization for memrefs creates descriptor structs from individual
+ // values constituting them, when descriptors are used, i.e. more than one
+ // value represents a memref.
+ addMaterialization([&](PatternRewriter &rewriter,
+ UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+ if (inputs.size() == 1)
+ return llvm::None;
+ return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
+ inputs);
+ });
+ addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
+ ValueRange inputs, Location loc) -> Optional<Value> {
+ if (inputs.size() == 1)
+ return llvm::None;
+ return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
+ });
}
/// Returns the MLIR context.
@@ -297,22 +315,6 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
}
-/// Creates descriptor structs from individual values constituting them.
-Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
- Type type,
- ArrayRef<Value> values,
- Location loc) {
- if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
- return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
- unrankedMemRefType, values)
- .getDefiningOp();
-
- auto memRefType = type.dyn_cast<MemRefType>();
- assert(memRefType && "1->N conversion is only supported for memrefs");
- return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
- .getDefiningOp();
-}
-
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 0e2bf03f9639..ff1ce3739d81 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -305,27 +305,18 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
// persist in the IR after conversion.
if (!origArg.use_empty()) {
rewriter.setInsertionPointToStart(newBlock);
- auto *newOp = typeConverter->materializeConversion(
- rewriter, origArg.getType(), llvm::None, loc);
- origArg.replaceAllUsesWith(newOp->getResult(0));
+ Value newArg = typeConverter->materializeConversion(
+ rewriter, loc, origArg.getType(), llvm::None);
+ assert(newArg &&
+ "Couldn't materialize a block argument after 1->0 conversion");
+ origArg.replaceAllUsesWith(newArg);
}
continue;
}
- // If mapping is 1-1, replace the remaining uses and drop the cast
- // operation.
- // FIXME(riverriddle) This should check that the result type and operand
- // type are the same, otherwise it should force a conversion to be
- // materialized.
- if (argInfo->newArgSize == 1) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
- continue;
- }
-
- // Otherwise this is a 1->N value mapping.
+ // Otherwise this is a 1->1+ value mapping.
Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
+ assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty())
@@ -333,7 +324,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
// If all users of the cast were removed, we can drop it. Otherwise, keep
// the operation alive and let the user handle any remaining usages.
- if (castValue.use_empty())
+ if (castValue.use_empty() && castValue.getDefiningOp())
castValue.getDefiningOp()->erase();
}
}
@@ -389,22 +380,22 @@ Block *ArgConverter::applySignatureConversion(
continue;
}
- // If this is a 1->1 mapping, then map the argument directly.
- if (inputMap->size == 1) {
- mapping.map(origArg, newArgs[inputMap->inputNo]);
- info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
- continue;
- }
-
- // Otherwise, this is a 1->N mapping. Call into the provided type converter
- // to pack the new values.
+ // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
+ // to pack the new values. For 1->1 mappings, if there is no materialization
+ // provided, use the argument directly instead.
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Operation *cast = typeConverter->materializeConversion(
- rewriter, origArg.getType(), replArgs, loc);
- assert(cast->getNumResults() == 1);
- mapping.map(origArg, cast->getResult(0));
+ Value newArg;
+ if (typeConverter)
+ newArg = typeConverter->materializeConversion(
+ rewriter, loc, origArg.getType(), replArgs);
+ if (!newArg) {
+ assert(replArgs.size() == 1 &&
+ "couldn't materialize the result of 1->N conversion");
+ newArg = replArgs.front();
+ }
+ mapping.map(origArg, newArg);
info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
+ ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
// Remove the original block from the region and return the new one.
@@ -1815,6 +1806,15 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
return success();
}
+Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
+ Location loc, Type resultType,
+ ValueRange inputs) {
+ for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
+ if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
+ return result.getValue();
+ return nullptr;
+}
+
/// Create a default conversion pattern that rewrites the type signature of a
/// FuncOp.
namespace {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 98f350053c04..284c38bad7b7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -48,6 +48,13 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
"work"(%arg0) : (f32) -> ()
}
+// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43)
+func @remap_materialize_1_to_1(%arg0: i42) {
+ // CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42
+ // CHECK: "test.return"(%[[V]])
+ "test.return"(%arg0) : (i42) -> ()
+}
+
// CHECK-LABEL: func @remap_input_to_self
func @remap_input_to_self(%arg0: index) {
// CHECK-NOT: test.cast
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index df3068cc6487..dab8e196111a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -477,7 +477,11 @@ struct TestNestedOpCreationUndoRewrite
namespace {
struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
- TestTypeConverter() { addConversion(convertType); }
+ TestTypeConverter() {
+ addConversion(convertType);
+ addMaterialization(materializeCast);
+ addMaterialization(materializeOneToOneCast);
+ }
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
// Drop I16 types.
@@ -490,6 +494,12 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
+ // Convert I42 to I43.
+ if (t.isInteger(42)) {
+ results.push_back(IntegerType::get(43, t.getContext()));
+ return success();
+ }
+
// Split F32 into F16,F16.
if (t.isF32()) {
results.assign(2, FloatType::getF16(t.getContext()));
@@ -501,12 +511,24 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
- /// Override the hook to materialize a conversion. This is necessary because
- /// we generate 1->N type mappings.
- Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
- ArrayRef<Value> inputs,
- Location loc) override {
- return rewriter.create<TestCastOp>(loc, resultType, inputs);
+ /// Hook for materializing a conversion. This is necessary because we generate
+ /// 1->N type mappings.
+ static Optional<Value> materializeCast(PatternRewriter &rewriter,
+ Type resultType, ValueRange inputs,
+ Location loc) {
+ if (inputs.size() == 1)
+ return inputs[0];
+ return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+ }
+
+ /// Materialize a cast back to i42 for the one-to-one conversion from i42 to i43.
+ static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
+ IntegerType resultType,
+ ValueRange inputs,
+ Location loc) {
+ if (resultType.getWidth() == 42 && inputs.size() == 1)
+ return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+ return llvm::None;
}
};
More information about the Mlir-commits
mailing list