[Mlir-commits] [mlir] 0ebb050 - [MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold
Matthias Gehre
llvmlistbot at llvm.org
Mon Jul 17 01:14:43 PDT 2023
Author: Matthias Gehre
Date: 2023-07-17T10:14:37+02:00
New Revision: 0ebb0503113e97eb14dc679f06bdc1d2e7296d54
URL: https://github.com/llvm/llvm-project/commit/0ebb0503113e97eb14dc679f06bdc1d2e7296d54
DIFF: https://github.com/llvm/llvm-project/commit/0ebb0503113e97eb14dc679f06bdc1d2e7296d54.diff
LOG: [MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold
reshape(reshape(x)) -> reshape(x) can be written directly as a fold instead of a canonicalization pattern,
which helps other passes clean up while they work.
This initially broke ReshapeConverterExpand/Collapse, which relied on creating foldable reshapes and a carefully
crafted benefit priority among the patterns.
I turned these into a single pattern on reshapes, which expands and/or collapses as needed in one go.
Differential Revision: https://reviews.llvm.org/D155266
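For illustration (not part of this commit): because the rewrite now lives in fold(), it also fires through
OpBuilder::createOrFold<>, so a pass that builds a reshape of a reshape gets the collapsed form immediately
instead of waiting for a canonicalization run. A minimal C++ sketch, using a hypothetical helper name:

// Hypothetical helper, sketch only. createOrFold invokes ReshapeOp::fold
// right after creating the op; if `input` is itself a tosa.reshape, the
// fold reassigns the new op's operand to the inner reshape's input,
// bypassing it.
static Value buildReshape(OpBuilder &builder, Location loc,
                          RankedTensorType resultTy, Value input) {
  return builder.createOrFold<tosa::ReshapeOp>(
      loc, resultTy, input,
      builder.getDenseI64ArrayAttr(resultTy.getShape()));
}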
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e5b4e664202f7d..7a16b37f0ca417 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1480,7 +1480,6 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
No data conversion happens during a reshape operation.
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 1a9f48e311749a..f51ada8d08b5ed 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -129,81 +129,74 @@ static bool createReassociationMapsForCollapse(
}
namespace {
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
- using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
- ShapedType resultTy = cast<ShapedType>(reshape.getType());
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && resultTy.getRank() != 1) {
- return rewriter.notifyMatchFailure(
- reshape, "Cannot collapse dynamic dims to more than one dimension");
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
- resultTy.getShape(),
- reassociationMap, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape,
- "tosa.reshape Attempting to collapse into an incompatible shape");
- }
+Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
+ ShapedType resultTy, Value operand) {
+ ShapedType operandTy = cast<ShapedType>(operand.getType());
+ if (resultTy == operandTy)
+ return operand;
+
+ bool isDynamic = !operandTy.hasStaticShape();
+
+ if (isDynamic && resultTy.getRank() != 1) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "Cannot collapse dynamic dims to more than one dimension");
+ return {};
+ }
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot collapse into given shape");
- }
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+ resultTy.getShape(),
+ reassociationMap, isDynamic)) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "tosa.reshape Attempting to collapse into an incompatible shape");
+ return {};
+ }
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
- return success();
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic)) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "tosa.reshape Cannot collapse into given shape");
+ return {};
}
-};
+ return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
+ reassociationMap);
+}
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
- using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
+ ShapedType resultTy, Value operand) {
+ ShapedType operandTy = cast<ShapedType>(operand.getType());
+ if (resultTy == operandTy)
+ return operand;
- LogicalResult
- matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
- ShapedType resultTy = cast<ShapedType>(reshape.getType());
- bool isDynamic = !operandTy.hasStaticShape();
+ bool isDynamic = !operandTy.hasStaticShape();
- if (isDynamic && operandTy.getRank() != 1) {
- return rewriter.notifyMatchFailure(
- reshape, "Cannot expand dynamic dims from more than one dimension");
- }
+ if (isDynamic && operandTy.getRank() != 1) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "Cannot expand dynamic dims from more than one dimension");
+ return {};
+ }
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
- operandTy.getShape(),
- reassociationMap, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape,
- "tosa.reshape Attempting to expand into an incompatible shape");
- }
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+ operandTy.getShape(),
+ reassociationMap, isDynamic)) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "tosa.reshape Attempting to expand into an incompatible shape");
+ return {};
+ }
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic) ||
- intermediateShape != operandTy.getShape()) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot expand into given shape");
- }
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
- return success();
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic) ||
+ intermediateShape != operandTy.getShape()) {
+ (void)rewriter.notifyMatchFailure(
+ loc, "tosa.reshape Cannot expand into given shape");
+ return {};
}
-};
+ return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+ reassociationMap);
+}
class ReshapeConverterCollapseExpand
: public OpConversionPattern<tosa::ReshapeOp> {
@@ -224,17 +217,19 @@ class ReshapeConverterCollapseExpand
reshape, "tosa.reshape Cannot identify an intermediate shape between "
"the given two shapes");
}
+ auto intermediateTy = RankedTensorType::get(
+ intermediateShape, reshape.getType().getElementType());
- Value collapse = rewriter.create<tosa::ReshapeOp>(
- reshape.getLoc(),
- RankedTensorType::get(intermediateShape,
- reshape.getType().getElementType()),
- adaptor.getInput1(), rewriter.getDenseI64ArrayAttr(intermediateShape));
- Value expand = rewriter.create<tosa::ReshapeOp>(
- reshape.getLoc(), resultTy, collapse,
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
- rewriter.replaceOp(reshape, expand);
+ Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
+ adaptor.getInput1());
+ if (!collapse)
+ return failure();
+ Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
+ if (!expand)
+ return failure();
+
+ rewriter.replaceOp(reshape, expand);
return success();
}
};
@@ -420,10 +415,6 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<SliceConverter, PadConverter, ConcatConverter>(
patterns->getContext());
- patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
- /*benefit=*/100);
- patterns->add<ReshapeConverterExpand>(patterns->getContext(),
- /*benefit=*/200);
- patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
- /*benefit=*/300);
+
+ patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8cefa64bc4c1fe..152b8857393bba 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -62,31 +62,6 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConcatOptimization>(context);
}
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
- using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::ReshapeOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getInput1();
- Operation *definingOp = input.getDefiningOp();
- if (!definingOp)
- return failure();
-
- if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
- return success();
- }
-
- return failure();
- }
-};
-
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ReshapeReshapeOptimization>(context);
-}
-
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
@@ -820,25 +795,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (inputTy == outputTy)
return getInput1();
- // Constants must have static shape.
- if (!outputTy.hasStaticShape())
- return {};
+ // reshape(reshape(x)) -> reshape(x)
+ if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
+ getInput1().getDefiningOp())) {
+ getInput1Mutable().assign(reshapeOp.getInput1());
+ return getResult();
+ }
- auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- if (!operand)
- return {};
+ // reshape(const(x)) -> const(reshape-attr(x))
+ if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+ // Constants must have static shape.
+ if (!outputTy.hasStaticShape())
+ return {};
- // Okay to duplicate splat constants.
- if (operand.isSplat()) {
- return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
- }
+ // Okay to duplicate splat constants.
+ if (operand.isSplat())
+ return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
- // Don't duplicate other constants.
- if (!getInput1().hasOneUse())
- return {};
+ // Don't duplicate other constants.
+ if (!getInput1().hasOneUse())
+ return {};
- return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+ return operand.reshape(
+ llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+ }
+
+ return {};
}
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
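A side note on the fold idiom used above (an illustration, not code from the patch): returning the op's own
result from fold() signals an in-place update, so fold drivers keep the op alive with its mutated operand
rather than replacing it with another value. Sketched as a hypothetical free function:

// Sketch of the in-place fold idiom, mirroring ReshapeOp::fold above.
static OpFoldResult foldReshapeOfReshape(tosa::ReshapeOp op) {
  if (auto producer = op.getInput1().getDefiningOp<tosa::ReshapeOp>()) {
    // Redirect the operand past the inner reshape; the inner op becomes
    // dead if this was its last use.
    op.getInput1Mutable().assign(producer.getInput1());
    return op.getResult(); // in-place fold: keep the op, signal success
  }
  return {}; // no fold applied
}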