[Mlir-commits] [mlir] f4cd367 - [mlir][scf] refactor scf structuralOpConversion to better support 1:N type conversion
Peiming Liu
llvmlistbot at llvm.org
Wed Nov 2 09:45:47 PDT 2022
Author: Peiming Liu
Date: 2022-11-02T16:45:39Z
New Revision: f4cd3674ea2e055a770fa993cbfd2356c38fc545
URL: https://github.com/llvm/llvm-project/commit/f4cd3674ea2e055a770fa993cbfd2356c38fc545
DIFF: https://github.com/llvm/llvm-project/commit/f4cd3674ea2e055a770fa993cbfd2356c38fc545.diff
LOG: [mlir][scf] refactor scf structuralOpConversion to better support 1:N type conversion
This patch moves the 1:N type mapping into its own classes to allow better code reuse in D137100.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D137099
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c4c219617b782..a441b6c80b75b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -32,22 +32,82 @@ static void unpackUnrealizedConversionCast(Value v,
unpacked.push_back(v);
}
-class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+// CRTP
+// A base class that takes care of 1:N type conversion, which maps the converted
+// op results (computed by the derived class) and materializes 1:N conversion.
+template <typename SourceOp, typename ConcretePattern>
+class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using OpConversionPattern<SourceOp>::typeConverter;
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+
+ //
+ // Derived classes should provide the following method which performs the
+ // actual conversion. It should return llvm::None upon conversion failure and
+ // return the converted operation upon success.
+ //
+ // Optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
+ // ConversionPatternRewriter &rewriter,
+ // TypeRange dstTypes) const;
+
LogicalResult
- matchAndRewrite(ForOp op, OpAdaptor adaptor,
+ matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type> newResultTypes;
+ SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, newResultTypes)))
- return rewriter.notifyMatchFailure(op, "could not convert result");
- offsets.push_back(newResultTypes.size());
+ if (failed(typeConverter->convertTypes(type, dstTypes)))
+ return rewriter.notifyMatchFailure(op, "could not convert result type");
+ offsets.push_back(dstTypes.size());
}
+ // Calls the actual converter implementation to convert the operation.
+ Optional<SourceOp> newOp =
+ static_cast<const ConcretePattern *>(this)->convertSourceOp(
+ op, adaptor, rewriter, dstTypes);
+
+ if (!newOp)
+ return rewriter.notifyMatchFailure(op, "could not convert operation");
+
+ // Packs the return value.
+ SmallVector<Value> packedRets;
+ for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+ unsigned start = offsets[i - 1], end = offsets[i];
+ unsigned len = end - start;
+ ValueRange mappedValue = newOp->getResults().slice(start, len);
+ if (len != 1) {
+ // 1 : N type conversion.
+ Type origType = op.getResultTypes()[i - 1];
+ Value mat = typeConverter->materializeSourceConversion(
+ rewriter, op.getLoc(), origType, mappedValue);
+ if (!mat) {
+ return rewriter.notifyMatchFailure(
+ op, "Failed to materialize 1:N type conversion");
+ }
+ packedRets.push_back(mat);
+ } else {
+ // 1 : 1 type conversion.
+ packedRets.push_back(mappedValue.front());
+ }
+ }
+
+ rewriter.replaceOp(op, packedRets);
+ return success();
+ }
+};
+
+class ConvertForOpTypes
+ : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
+public:
+ using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+ // The callback required by CRTP.
+ Optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
//
// This is a little bit tricky. We have two concerns here:
@@ -67,15 +127,15 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
// convertRegionTypes already takes care of 1:N conversion.
if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
- return failure();
+ return llvm::None;
// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getInitArgs())
unpackUnrealizedConversionCast(arg, flatArgs);
- // We can not do clone as the number of result types after conversion might
- // be different.
+ // We can not do clone as the number of result types after conversion
+ // might be different.
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
adaptor.getUpperBound(),
adaptor.getStep(), flatArgs);
@@ -89,29 +149,7 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());
- // Pack the return value.
- SmallVector<Value, 6> packedRets;
- for (unsigned i = 1, e = offsets.size(); i < e; i++) {
- unsigned start = offsets[i - 1], end = offsets[i];
- unsigned len = end - start;
- ValueRange mappedValue = newOp.getResults().slice(start, len);
- if (len != 1) {
- // 1 : N type conversion.
- Type origType = op.getResultTypes()[i - 1];
- Value mat = typeConverter->materializeSourceConversion(
- rewriter, op.getLoc(), origType, mappedValue);
- if (!mat)
- return rewriter.notifyMatchFailure(
- op, "Failed to materialize 1:N type conversion");
- packedRets.push_back(mat);
- } else {
- // 1 : 1 type conversion.
- packedRets.push_back(mappedValue.front());
- }
- }
-
- rewriter.replaceOp(op, packedRets);
- return success();
+ return newOp;
}
};
} // namespace
More information about the Mlir-commits mailing list