[Mlir-commits] [mlir] 0ebb050 - [MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold

Matthias Gehre llvmlistbot at llvm.org
Mon Jul 17 01:14:43 PDT 2023


Author: Matthias Gehre
Date: 2023-07-17T10:14:37+02:00
New Revision: 0ebb0503113e97eb14dc679f06bdc1d2e7296d54

URL: https://github.com/llvm/llvm-project/commit/0ebb0503113e97eb14dc679f06bdc1d2e7296d54
DIFF: https://github.com/llvm/llvm-project/commit/0ebb0503113e97eb14dc679f06bdc1d2e7296d54.diff

LOG: [MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold

The rewrite reshape(reshape(x)) -> reshape(x) can be expressed directly as a fold instead of a canonicalization pattern,
so that other passes can clean up such reshape chains while they run.

This initially broke ReshapeConverterExpand/Collapse, which relied on creating foldable reshapes and on a carefully tuned
ordering of pattern benefits.
I replaced them with a single pattern on reshapes that performs the expand and/or collapse as needed in one step.

Differential Revision: https://reviews.llvm.org/D155266

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e5b4e664202f7d..7a16b37f0ca417 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1480,7 +1480,6 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
     No data conversion happens during a reshape operation.
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 

diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 1a9f48e311749a..f51ada8d08b5ed 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -129,81 +129,74 @@ static bool createReassociationMapsForCollapse(
 }
 
 namespace {
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
+Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
+                     ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
+
+  bool isDynamic = !operandTy.hasStaticShape();
+
+  if (isDynamic && resultTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot collapse dynamic dims to more than one dimension");
+    return {};
+  }
 
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                          resultTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
+    return {};
+  }
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot collapse into given shape");
+    return {};
   }
-};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
+                                                  reassociationMap);
+}
 
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
+                   ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
 
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
+  bool isDynamic = !operandTy.hasStaticShape();
 
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
+  if (isDynamic && operandTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot expand dynamic dims from more than one dimension");
+    return {};
+  }
 
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
-    }
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                          operandTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to expand into an incompatible shape");
+    return {};
+  }
 
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
-        intermediateShape != operandTy.getShape()) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot expand into given shape");
-    }
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic) ||
+      intermediateShape != operandTy.getShape()) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot expand into given shape");
+    return {};
   }
-};
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+                                                reassociationMap);
+}
 
 class ReshapeConverterCollapseExpand
     : public OpConversionPattern<tosa::ReshapeOp> {
@@ -224,17 +217,19 @@ class ReshapeConverterCollapseExpand
           reshape, "tosa.reshape Cannot identify an intermediate shape between "
                    "the given two shapes");
     }
+    auto intermediateTy = RankedTensorType::get(
+        intermediateShape, reshape.getType().getElementType());
 
-    Value collapse = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(),
-        RankedTensorType::get(intermediateShape,
-                              reshape.getType().getElementType()),
-        adaptor.getInput1(), rewriter.getDenseI64ArrayAttr(intermediateShape));
-    Value expand = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(), resultTy, collapse,
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
-    rewriter.replaceOp(reshape, expand);
+    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
+                                    adaptor.getInput1());
+    if (!collapse)
+      return failure();
 
+    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
+    if (!expand)
+      return failure();
+
+    rewriter.replaceOp(reshape, expand);
     return success();
   }
 };
@@ -420,10 +415,6 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<SliceConverter, PadConverter, ConcatConverter>(
       patterns->getContext());
-  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
-                                          /*benefit=*/100);
-  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
-                                        /*benefit=*/200);
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
-                                                /*benefit=*/300);
+
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8cefa64bc4c1fe..152b8857393bba 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -62,31 +62,6 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConcatOptimization>(context);
 }
 
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput1();
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
-      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-          op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<ReshapeReshapeOptimization>(context);
-}
-
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
@@ -820,25 +795,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy)
     return getInput1();
 
-  // Constants must have static shape.
-  if (!outputTy.hasStaticShape())
-    return {};
+  // reshape(reshape(x)) -> reshape(x)
+  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
+          getInput1().getDefiningOp())) {
+    getInput1Mutable().assign(reshapeOp.getInput1());
+    return getResult();
+  }
 
-  auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  if (!operand)
-    return {};
+  // reshape(const(x)) -> const(reshape-attr(x))
+  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+    // Constants must have static shape.
+    if (!outputTy.hasStaticShape())
+      return {};
 
-  // Okay to duplicate splat constants.
-  if (operand.isSplat()) {
-    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
-  }
+    // Okay to duplicate splat constants.
+    if (operand.isSplat())
+      return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
 
-  // Don't duplicate other constants.
-  if (!getInput1().hasOneUse())
-    return {};
+    // Don't duplicate other constants.
+    if (!getInput1().hasOneUse())
+      return {};
 
-  return operand.reshape(
-      llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+    return operand.reshape(
+        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+  }
+
+  return {};
 }
 
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {


        


More information about the Mlir-commits mailing list