[Mlir-commits] [mlir] 6412a13 - [mlir] Move common reshapeops-related code to ReshapeOpsUtils.h.

Alexander Belyaev llvmlistbot at llvm.org
Wed Jul 7 05:56:43 PDT 2021


Author: Alexander Belyaev
Date: 2021-07-07T14:56:16+02:00
New Revision: 6412a13539ab2914eed8e1df83c399b9a16e3408

URL: https://github.com/llvm/llvm-project/commit/6412a13539ab2914eed8e1df83c399b9a16e3408
DIFF: https://github.com/llvm/llvm-project/commit/6412a13539ab2914eed8e1df83c399b9a16e3408.diff

LOG: [mlir] Move common reshapeops-related code to ReshapeOpsUtils.h.

This is a first step towards moving (Tensor)Expand/CollapseShapeOp to the
tensor/memref dialects.

Differential Revision: https://reviews.llvm.org/D105547

Added: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 87c279a85e1b..b8a3afd48e40 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -52,16 +53,6 @@ using LoopRangeBuilder =
 /// provide an op-specified hook so that Linalg ops may override the behavior.
 LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
 
-using ReassociationIndices = SmallVector<int64_t, 2>;
-using ReassociationIndicesRef = ArrayRef<int64_t>;
-using ReassociationExprs = SmallVector<AffineExpr, 2>;
-
-/// Return the reassociations maps to use to reshape given the source type and
-/// the target type when possible. Return llvm::None when this computation
-/// failed.
-Optional<SmallVector<ReassociationIndices>>
-getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
-
 /// Returns the name mangled library call name to disambiguate between different
 /// overloads at the C level. The name mangling scheme is basic and uses MLIR
 /// type names:

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
new file mode 100644
index 000000000000..0e9a9a4b3894
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -0,0 +1,266 @@
+//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities and common canonicalization patterns for
+// reshape operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+using ReassociationIndices = SmallVector<int64_t, 2>;
+using ReassociationIndicesRef = ArrayRef<int64_t>;
+using ReassociationExprs = SmallVector<AffineExpr, 2>;
+
+/// Attribute name for the ArrayAttr which encodes reassociation indices.
+StringRef getReassociationAttrName();
+
+/// Collapse reassociation maps that are used in a pair of reshape ops where
+/// one is the producer and the other is the consumer. This method is only
+/// valid when both the producer and the consumer are collapsing dimensions or
+/// both are expanding dimensions.
+///
+/// For example,
+///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2) -> (d2)>]
+///
+/// is folded into
+///
+///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+/// TODO: Use reassociation indices instead of affine maps here.
+Optional<SmallVector<ReassociationIndices>>
+collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
+                             ArrayRef<AffineMap> mapsConsumer,
+                             MLIRContext *context);
+
+/// Return the reassociation maps to use to reshape the source type into the
+/// target type, when possible. Return llvm::None when this computation
+/// fails.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
+
+/// Return true if the reassociation specification is valid, false otherwise.
+/// When false, the `invalidIndex` integer pointer is optionally filled with the
+/// index of the offending reassociation map.
+bool isReassociationValid(ArrayRef<AffineMap> reassociation,
+                          int *invalidIndex = nullptr);
+
+/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
+
+/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+template <typename ReshapeLikeOp>
+void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
+  p << op.getOperationName() << ' ' << op.src() << " [";
+
+  llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
+    p << '[';
+    auto arrayAttr = attr.template cast<ArrayAttr>();
+    llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
+      p << attr.cast<IntegerAttr>().getInt();
+    });
+    p << ']';
+  });
+
+  p << "] ";
+  p.printOptionalAttrDict(op->getAttrs(),
+                          /*elidedAttrs=*/{op.getReassociationAttrName()});
+  p << ": " << op.src().getType() << " into " << op.getType();
+}
+
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
+                                  ArrayRef<Attribute> operands) {
+  // Fold producer-consumer reshape ops where the operand type of the producer
+  // is the same as the return type of the consumer.
+  auto reshapeSrcOp =
+      reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+    return reshapeSrcOp.src();
+  // Reshape of a constant can be replaced with a new constant.
+  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
+    return elements.reshape(
+        reshapeOp.getResult().getType().template cast<ShapedType>());
+  }
+  return nullptr;
+}
+
+/// Common verifier for reshape-like types. Fills `expandedType` and
+/// `collapsedType` with the proper `src` or `result` type.
+template <typename Op, typename T>
+static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
+                                            T collapsedType, bool isExpansion) {
+  unsigned expandedRank = expandedType.getRank();
+  unsigned collapsedRank = collapsedType.getRank();
+  if (expandedRank < collapsedRank)
+    return op.emitOpError("expected the type ")
+           << expandedType
+           << " to have higher rank than the type = " << collapsedType;
+  if (expandedRank == 0)
+    return op.emitOpError("expected non-zero memref ranks");
+  if (expandedRank == collapsedRank)
+    return op.emitOpError("expected to collapse or expand dims");
+
+  if (collapsedRank == 0) {
+    // If the collapsed rank is 0, then the expanded type must be statically
+    // shaped with all dimensions of size 1.
+    if (llvm::any_of(expandedType.getShape(),
+                     [](int64_t dim) -> bool { return dim != 1; }))
+      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
+                            "extent dimensions to zero-rank tensor/memref");
+    return success();
+  }
+  if (collapsedRank != op.reassociation().size())
+    return op.emitOpError("expected rank of the collapsed type(")
+           << collapsedRank << ") to be the number of reassociation maps("
+           << op.reassociation().size() << ")";
+  auto maps = op.getReassociationMaps();
+  for (auto it : llvm::enumerate(maps))
+    if (it.value().getNumDims() != expandedRank)
+      return op.emitOpError("expected reassociation map #")
+             << it.index() << " of same rank as expanded memref("
+             << expandedRank << "), but got " << it.value().getNumDims();
+  int invalidIdx = 0;
+  if (!isReassociationValid(maps, &invalidIdx))
+    return op.emitOpError("expected reassociation map #")
+           << invalidIdx << " to be valid and contiguous";
+  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+}
+
+/// Verify the shapes of the reshaped types using the following rules:
+/// 1) If a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape must be
+///    a) static, and
+///    b) their product must equal the size of the collapsed dimension.
+/// 2) If a dimension in the collapsed type is dynamic, exactly one of the
+///    corresponding dimensions in the expanded type must be dynamic. This
+///    rule is only needed for reshape operations that are expanding.
+template <typename OpTy>
+static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
+                                             ShapedType expandedType,
+                                             bool isExpandingReshape) {
+  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
+  ArrayRef<int64_t> expandedShape = expandedType.getShape();
+  unsigned expandedDimStart = 0;
+  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
+    Optional<int64_t> dynamicShape;
+    int64_t linearizedStaticShape = 1;
+    for (auto dim : llvm::enumerate(expandedShape.slice(
+             expandedDimStart, map.value().getNumResults()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicShape) {
+          return op->emitOpError("invalid to have a single dimension (")
+                 << map.index() << ") expanded into multiple dynamic dims ("
+                 << expandedDimStart + dynamicShape.getValue() << ","
+                 << expandedDimStart + dim.index() << ")";
+        }
+        dynamicShape = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicShape) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return op->emitOpError("expected dimension ")
+               << map.index()
+               << " of collapsed type to be dynamic since one or more of the "
+                  "corresponding dimensions in the expanded type is dynamic";
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return op->emitOpError("expected dimension ")
+               << map.index() << " of collapsed type to be static value of "
+               << linearizedStaticShape << " ";
+      }
+    }
+    expandedDimStart += map.value().getNumResults();
+  }
+  return success();
+}
+
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template <typename ReshapeOpTy>
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
+    if (!srcReshapeOp)
+      return failure();
+
+    ShapedType resultType = reshapeOp.getResultType();
+    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
+                                     reshapeOp.getReassociationMaps(),
+                                     rewriter.getContext());
+    if (!reassociationIndices)
+      return failure();
+    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+        reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+    return success();
+  }
+};
+
+/// Pattern to fold producer/consumer reshape ops where one collapses
+/// dimensions and the other expands them.
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcReshapeOp =
+        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+    if (!srcReshapeOp)
+      return failure();
+
+    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
+    ShapedType resultType = reshapeOp.getResultType();
+
+    // If the source reshape can be collapsed/expanded into the target
+    // reshape, the two ops can still be folded. This can only be reasoned
+    // about statically for cases where
+    // - all shapes are static, or
+    // - the number of dynamic dimensions matches between the source of the
+    //   source reshape and the result, with all other dimensions being 1.
+    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+    if (!reassociationIndices)
+      return failure();
+    bool originalOpExpands =
+        intermediateType.getRank() > srcReshapeSrcType.getRank();
+    bool resultingOpExpands =
+        resultType.getRank() > srcReshapeSrcType.getRank();
+    if (!(resultingOpExpands ^ originalOpExpands))
+      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+    else
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+    return success();
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

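As context for the new header: the helpers above are designed to drop straight
into an op's print/parse/fold/canonicalization hooks. Below is a minimal
sketch of that wiring, assuming a hypothetical dialect op
mydialect::CollapseShapeOp with the same src()/reassociation() accessors as
the linalg ops; it mirrors the linalg changes later in this patch.

  #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
  using namespace mlir;

  // Hypothetical op wiring; only the hook bodies are shown.
  static ParseResult parseCollapseShapeOp(OpAsmParser &parser,
                                          OperationState &result) {
    // Parses `%src [[0, 1], [2]] attr-dict : src-type into result-type`.
    return parseReshapeLikeOp(parser, result);
  }

  static void print(OpAsmPrinter &p, mydialect::CollapseShapeOp op) {
    // Prints the same custom assembly form back out.
    printReshapeOp<mydialect::CollapseShapeOp>(p, op);
  }

  OpFoldResult mydialect::CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
    // Folds collapse(expand(x)) -> x and reshapes of constants.
    return foldReshapeOp<CollapseShapeOp, mydialect::ExpandShapeOp>(*this,
                                                                    operands);
  }

  void mydialect::CollapseShapeOp::getCanonicalizationPatterns(
      RewritePatternSet &results, MLIRContext *context) {
    results.add<CollapseReshapeOps<CollapseShapeOp>,
                CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(
        context);
  }
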
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index f1fadccd2c2d..ceda829fcb58 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   MLIRConversionPassIncGen
 
   LINK_LIBS PUBLIC
+  MLIRDialectUtils
   MLIRIR
   MLIRLinalg
   MLIRLinalgUtils

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d4ac19a73ba7..9399a755243c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -1120,8 +1121,7 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
         (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                   : operandTy.getShape());
     unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
-        collapsedShape.size());
+    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
 
     // First scan all dimensions in the source shapes to see whether we have a
     // perfect case where consecutive dimensions in source are collapsed. For
@@ -1176,11 +1176,11 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
           std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                           std::multiplies<int64_t>());
       auto elemTy = operandTy.getElementType();
-      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+      SmallVector<ReassociationExprs, 4> collapsingMap = {
           // Use operandTy here because we need to collapse all operands
           // dimensions.
           getIdentityExprs(operandTy.getShape().size())};
-      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+      SmallVector<ReassociationExprs, 4> expandingMap = {
           // Use resultTy here because we need to expand to all result
           // dimensions.
           getIdentityExprs(resultTy.getShape().size())};

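The TosaToLinalg changes above simply drop the linalg:: qualifier now that the
reassociation aliases live at the mlir namespace level. For readers new to
these types, each ReassociationExprs entry is the group of AffineDimExprs that
one collapsed dimension expands from. A self-contained sketch (the shapes and
the helper name are illustrative only):

  #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
  #include "mlir/IR/AffineExpr.h"
  using namespace mlir;

  // Reassociation [[d0, d1], [d2]]: collapses a rank-3 shape such as 2x3x4
  // into the rank-2 shape 6x4 by grouping dims (0, 1) and (2).
  static SmallVector<ReassociationExprs, 4>
  makeCollapse3To2Reassociation(MLIRContext *ctx) {
    SmallVector<ReassociationExprs, 4> map(2);
    map[0].push_back(getAffineDimExpr(0, ctx)); // d0
    map[0].push_back(getAffineDimExpr(1, ctx)); // d1
    map[1].push_back(getAffineDimExpr(2, ctx)); // d2
    return map;
  }
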
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 62e66a421d5c..66cad6eaa3cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1069,338 +1069,20 @@ OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
-Optional<SmallVector<ReassociationIndices>>
-mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
-                                                ShapedType targetType) {
-  // Make the sourceType greater rank than the targetType. If they are same
-  // rank, then its an unsupported reshape op.
-  if (sourceType.getRank() == targetType.getRank())
-    return llvm::None;
-  if (sourceType.getRank() < targetType.getRank())
-    std::swap(sourceType, targetType);
-
-  ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  ArrayRef<int64_t> targetShape = targetType.getShape();
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetType.getRank());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-
-    // If all the dimensions of the targetShape are exhausted, then the
-    // remaining dims in the source shape must be all 1s. So for such cases, set
-    // 1 as the target shape. The actual reassociation indices will be handled
-    // later.
-    int64_t currTargetShape =
-        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
-    while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
-           sourceDim < sourceShape.size()) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
-    }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
-        (currTargetShape != ShapedType::kDynamicSize ||
-         prodOfCollapsedDims != 1))
-      return llvm::None;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamicSize &&
-        sourceShape[sourceDim] != ShapedType::kDynamicSize)
-      return llvm::None;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return llvm::None;
-
-    currIndices.push_back(sourceDim++);
-    // If the reassociation is empty but the currIndices is not, this by
-    // definition is folding unit-dimensions with the result being scalar type.
-    // So only append the `currIndices` if reassociation map is not empty.
-    if (targetDim == targetShape.size()) {
-      if (!reassociationMap.empty() && !currIndices.empty())
-        reassociationMap.back().append(currIndices.begin(), currIndices.end());
-      // Break out of the loops. We should be done here.
-      break;
-    }
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
-  }
-  // All the dimensions in the two shapes must have been processed.
-  if (reassociationMap.size() != targetShape.size() ||
-      sourceDim != sourceShape.size())
-    return llvm::None;
-  return reassociationMap;
-}
-
-template <typename ReshapeLikeOp>
-static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
-  p << op.getOperationName() << ' ' << op.src() << " [";
-
-  llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
-    p << '[';
-    auto arrayAttr = attr.template cast<ArrayAttr>();
-    llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
-      p << attr.cast<IntegerAttr>().getInt();
-    });
-    p << ']';
-  });
-
-  p << "] ";
-  p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{op.getReassociationAttrName()});
-  p << ": " << op.src().getType() << " into " << op.getType();
-}
-
 static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
-  print<linalg::ExpandShapeOp>(p, op);
+  ::mlir::printReshapeOp<linalg::ExpandShapeOp>(p, op);
 }
 
 static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
-  print<linalg::CollapseShapeOp>(p, op);
+  ::mlir::printReshapeOp<linalg::CollapseShapeOp>(p, op);
 }
 
 static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
-  print<linalg::TensorExpandShapeOp>(p, op);
+  ::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
 }
 
 static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
-  print<linalg::TensorCollapseShapeOp>(p, op);
-}
-
-static constexpr StringRef getReassociationAttrName() {
-  return "reassociation";
-}
-
-static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
-                                      OperationState &result) {
-  // Parse the operand.
-  OpAsmParser::OperandType src;
-  if (parser.parseOperand(src))
-    return failure();
-
-  // Parse reassociation indices.
-  Builder &b = parser.getBuilder();
-  SmallVector<Attribute, 4> reassociation;
-  if (parser.parseLSquare())
-    return failure();
-
-  while (true) {
-    if (succeeded(parser.parseOptionalRSquare()))
-      break;
-    if (parser.parseLSquare())
-      return failure();
-    SmallVector<int64_t> indices;
-    while (true) {
-      int64_t index;
-      if (parser.parseInteger(index))
-        return failure();
-      indices.push_back(index);
-
-      if (succeeded(parser.parseOptionalComma()))
-        continue;
-      if (failed(parser.parseRSquare()))
-        return failure();
-      break;
-    }
-    reassociation.push_back(b.getI64ArrayAttr(indices));
-    if (succeeded(parser.parseOptionalComma()))
-      continue;
-    if (failed(parser.parseRSquare()))
-      return failure();
-    break;
-  }
-
-  result.addAttribute(getReassociationAttrName(),
-                      b.getArrayAttr(reassociation));
-
-  // Parse optional attributes.
-  parser.parseOptionalAttrDict(result.attributes);
-
-  // Parse types.
-  Type srcType;
-  Type resultType;
-  if (parser.parseColon() || parser.parseType(srcType) ||
-      parser.resolveOperand(src, srcType, result.operands) ||
-      parser.parseKeyword("into") || parser.parseType(resultType))
-    return failure();
-  result.addTypes(resultType);
-  return success();
-}
-
-/// Collapse reassociation maps that are used in pair of reshape ops where one
-/// is a producer and other is the consumer. Only valid to use this method when
-/// both the producer and consumer are collapsing dimensions or both are
-/// expanding dimensions.
-///
-/// For example,
-///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2) -> (d2)>]
-///
-/// is folded into
-///
-///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-static Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                             ArrayRef<AffineMap> mapsConsumer,
-                             MLIRContext *context) {
-  // Make the producer the larger sized vector. If they are of same size, the
-  // resulting reshape is not a supported reshape op.
-  if (mapsProducer.size() == mapsConsumer.size())
-    return llvm::None;
-  if (mapsProducer.size() < mapsConsumer.size())
-    std::swap(mapsProducer, mapsConsumer);
-
-  // Handle the corner case of the result being a rank 0 shaped type. Return an
-  // empty reassociation.
-  if (mapsConsumer.empty())
-    return SmallVector<ReassociationIndices>{};
-  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
-    return llvm::None;
-
-  unsigned currDim = 0;
-  SmallVector<ReassociationIndices> reassociationMaps;
-  for (AffineMap rhs : mapsConsumer) {
-    ReassociationIndices reassociations;
-    for (AffineExpr rhsExpr : rhs.getResults()) {
-      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
-      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
-           i < e; ++i)
-        reassociations.push_back(currDim++);
-    }
-    reassociationMaps.push_back(std::move(reassociations));
-  }
-  return reassociationMaps;
-}
-
-namespace {
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
-  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
-    if (!srcReshapeOp)
-      return failure();
-
-    ShapedType resultType = reshapeOp.getResultType();
-    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
-                                     reshapeOp.getReassociationMaps(),
-                                     rewriter.getContext());
-    if (!reassociationIndices)
-      return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
-    return success();
-  }
-};
-
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
-  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcReshapeOp =
-        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
-    if (!srcReshapeOp)
-      return failure();
-
-    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
-    ShapedType intermediateType = reshapeOp.getSrcType();
-    ShapedType resultType = reshapeOp.getResultType();
-
-    // If the source reshape can be collapsed/expanded into the target reshape
-    // they can still be folded. This can only be reasoned about statically
-    // for cases where
-    // - either all shapes are static, or
-    // - The number of dynamic dimensions matches in the source of source and
-    //   result with all other dimensions being 1.
-    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
-    if (!reassociationIndices)
-      return failure();
-    bool originalOpExpands =
-        intermediateType.getRank() > srcReshapeSrcType.getRank();
-    bool resultingOpExpands =
-        resultType.getRank() > srcReshapeSrcType.getRank();
-    if (!(resultingOpExpands ^ originalOpExpands))
-      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
-    else
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
-    return success();
-  }
-};
-} // namespace
-
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
-                                  ArrayRef<Attribute> operands) {
-  // Fold producer-consumer reshape ops that where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.src();
-  // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
-    return elements.reshape(
-        reshapeOp.getResult().getType().template cast<ShapedType>());
-  }
-  return nullptr;
-}
-
-/// Return true if the reassociation specification is valid, false otherwise.
-/// When false, the `invalidIndex` integer pointer is optionally filled with the
-/// index of the offending reassociation map.
-static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
-                                 int *invalidIndex = nullptr) {
-  if (reassociation.empty())
-    return true;
-  unsigned nDims = reassociation[0].getNumDims();
-  unsigned nextExpectedDim = 0;
-  for (auto it : llvm::enumerate(reassociation)) {
-    auto m = it.value();
-    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
-      if (invalidIndex)
-        *invalidIndex = it.index();
-      return false;
-    }
-    for (auto e : m.getResults()) {
-      auto d = e.dyn_cast<AffineDimExpr>();
-      if (!d || d.getPosition() != nextExpectedDim++) {
-        if (invalidIndex)
-          *invalidIndex = it.index();
-        return false;
-      }
-    }
-  }
-  if (nextExpectedDim != nDims) {
-    if (invalidIndex)
-      *invalidIndex = reassociation.size() - 1;
-    return false;
-  }
-  return true;
+  ::mlir::printReshapeOp<linalg::TensorCollapseShapeOp>(p, op);
 }
 
 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
@@ -1736,106 +1418,12 @@ void mlir::linalg::CollapseShapeOp::build(
 
 Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
 
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-///    dimensions in the expanded shape should be
-///    a) static
-///    b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-///    corresponding dimensions in the expanded type should be dynamic. This
-///    rule is only needed with reshape operations that are expanding.
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
-                                             ShapedType expandedType,
-                                             bool isExpandingReshape) {
-  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
-  ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  unsigned expandedDimStart = 0;
-  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional<int64_t> dynamicShape;
-    int64_t linearizedStaticShape = 1;
-    for (auto dim : llvm::enumerate(expandedShape.slice(
-             expandedDimStart, map.value().getNumResults()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return op->emitOpError("invalid to have a single dimension (")
-                 << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicShape.getValue() << ","
-                 << expandedDimStart + dim.index() << ")";
-        }
-        dynamicShape = dim.index();
-      } else {
-        linearizedStaticShape *= dim.value();
-      }
-    }
-    if (dynamicShape) {
-      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
-        return op->emitOpError("expected dimension ")
-               << map.index()
-               << " of collapsed type to be dynamic since one or more of the "
-                  "corresponding dimensions in the expanded type is dynamic";
-      }
-    } else {
-      if (collapsedShape[map.index()] != linearizedStaticShape) {
-        return op->emitOpError("expected dimension ")
-               << map.index() << " of collapsed type to be static value of "
-               << linearizedStaticShape << " ";
-      }
-    }
-    expandedDimStart += map.value().getNumResults();
-  }
-  return success();
-}
-
-// Common verifier for reshape-like types. Fills `expandedType` and
-// `collapsedType` with the proper `src` or `result` type.
-template <typename Op, typename T,
-          bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
-                             std::is_same<Op, ExpandShapeOp>::value>
-static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
-                                            T collapsedType) {
-  unsigned expandedRank = expandedType.getRank();
-  unsigned collapsedRank = collapsedType.getRank();
-  if (expandedRank < collapsedRank)
-    return op.emitOpError("expected the type ")
-           << expandedType
-           << " to have higher rank than the type = " << collapsedType;
-  if (expandedRank == 0)
-    return op.emitOpError("expected non-zero memref ranks");
-  if (expandedRank == collapsedRank)
-    return op.emitOpError("expected to collapse or expand dims");
-
-  if (collapsedRank == 0) {
-    // If collapsed rank is 0, then expanded type must be static shaped and of
-    // sizes 1.
-    if (llvm::any_of(expandedType.getShape(),
-                     [](int64_t dim) -> bool { return dim != 1; }))
-      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
-                            "extent dimensions to zero-rank tensor/memref");
-    return success();
-  }
-  if (collapsedRank != op.reassociation().size())
-    return op.emitOpError("expected rank of the collapsed type(")
-           << collapsedRank << ") to be the number of reassociation maps("
-           << op.reassociation().size() << ")";
-  auto maps = op.getReassociationMaps();
-  for (auto it : llvm::enumerate(maps))
-    if (it.value().getNumDims() != expandedRank)
-      return op.emitOpError("expected reassociation map #")
-             << it.index() << " of same rank as expanded memref("
-             << expandedRank << "), but got " << it.value().getNumDims();
-  int invalidIdx = 0;
-  if (!isReassociationValid(maps, &invalidIdx))
-    return op.emitOpError("expected reassociation map #")
-           << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
-}
-
-template <typename TensorReshapeOp>
-static LogicalResult verifyReshapeOp(TensorReshapeOp op,
-                                     MemRefType expandedType,
+template <typename ReshapeOp,
+          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
                                      MemRefType collapsedType) {
-  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+  if (failed(
+          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
     return failure();
   auto maps = op.getReassociationMaps();
   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
@@ -1923,11 +1511,14 @@ void mlir::linalg::TensorExpandShapeOp::build(
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-template <typename TensorReshapeOp>
+template <typename TensorReshapeOp,
+          bool isExpansion =
+              std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                            RankedTensorType expandedType,
                                            RankedTensorType collapsedType) {
-  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+  if (failed(
+          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
     return failure();
 
   auto maps = op.getReassociationMaps();

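One behavioral note on the LinalgOps.cpp changes above: verifyReshapeLikeTypes
used to deduce isExpansion from the concrete linalg op type, which the shared
header cannot do, so callers now pass the flag explicitly. A sketch of a
downstream verifier doing the same (MyExpandOp is hypothetical):

  #include <type_traits>
  #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
  #include "mlir/IR/BuiltinTypes.h"
  using namespace mlir;

  // Mirrors verifyReshapeOp above; `isExpansion` enables the rule that each
  // reassociation group may contain at most one dynamic expanded dimension.
  template <typename ReshapeOp,
            bool isExpansion = std::is_same<ReshapeOp, MyExpandOp>::value>
  static LogicalResult verifyMyReshapeOp(ReshapeOp op, MemRefType expandedType,
                                         MemRefType collapsedType) {
    return verifyReshapeLikeTypes(op, expandedType, collapsedType,
                                  isExpansion);
  }
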
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index 098b6b48b032..7e6d2978af41 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRDialectUtils
+  ReshapeOpsUtils.cpp
   StructuredOpsUtils.cpp
   StaticValueUtils.cpp
 

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
new file mode 100644
index 000000000000..b353b3195e59
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -0,0 +1,209 @@
+//===- ReshapeOpsUtils.cpp - Utilities used by reshape ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+StringRef mlir::getReassociationAttrName() { return "reassociation"; }
+
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForReshape(ShapedType sourceType,
+                                        ShapedType targetType) {
+  // Make sourceType the type with the greater rank. If the ranks are equal,
+  // then it is an unsupported reshape op.
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
+  if (sourceType.getRank() < targetType.getRank())
+    std::swap(sourceType, targetType);
+
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetType.getRank());
+
+  ReassociationIndices currIndices;
+  int64_t prodOfCollapsedDims = 1;
+  while (sourceDim < sourceShape.size()) {
+    unsigned targetDim = reassociationMap.size();
+
+    // If all the dimensions of the targetShape are exhausted, then the
+    // remaining dims in the source shape must be all 1s. So for such cases, set
+    // 1 as the target shape. The actual reassociation indices will be handled
+    // later.
+    int64_t currTargetShape =
+        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+    while (sourceDim < sourceShape.size() &&
+           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
+    }
+
+    // If the current expanded dimension is dynamic, then the collapsed
+    // dimension should also be dynamic, and the product of the expanded
+    // dimensions processed so far for this group should be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
+        (currTargetShape != ShapedType::kDynamicSize ||
+         prodOfCollapsedDims != 1))
+      return llvm::None;
+
+    // If the collapsed dim is dynamic, the current expanded dim should also
+    // be dynamic.
+    if (currTargetShape == ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != ShapedType::kDynamicSize)
+      return llvm::None;
+
+    // For static shapes, the product of the dimensions of the expanded shape
+    // should match the collapsed dimension's size.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return llvm::None;
+
+    currIndices.push_back(sourceDim++);
+    // If the reassociation map is empty but `currIndices` is not, we are by
+    // definition folding unit dimensions into a scalar result. So only append
+    // `currIndices` if the reassociation map is not empty.
+    if (targetDim == targetShape.size()) {
+      if (!reassociationMap.empty() && !currIndices.empty())
+        reassociationMap.back().append(currIndices.begin(), currIndices.end());
+      // Break out of the loops. We should be done here.
+      break;
+    }
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
+  }
+  // All the dimensions in the two shapes must have been processed.
+  if (reassociationMap.size() != targetShape.size() ||
+      sourceDim != sourceShape.size())
+    return llvm::None;
+  return reassociationMap;
+}
+
+ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Parse the operand.
+  OpAsmParser::OperandType src;
+  if (parser.parseOperand(src))
+    return failure();
+
+  // Parse reassociation indices.
+  Builder &b = parser.getBuilder();
+  SmallVector<Attribute, 4> reassociation;
+  if (parser.parseLSquare())
+    return failure();
+
+  while (true) {
+    if (succeeded(parser.parseOptionalRSquare()))
+      break;
+    if (parser.parseLSquare())
+      return failure();
+    SmallVector<int64_t> indices;
+    while (true) {
+      int64_t index;
+      if (parser.parseInteger(index))
+        return failure();
+      indices.push_back(index);
+
+      if (succeeded(parser.parseOptionalComma()))
+        continue;
+      if (failed(parser.parseRSquare()))
+        return failure();
+      break;
+    }
+    reassociation.push_back(b.getI64ArrayAttr(indices));
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  result.addAttribute(getReassociationAttrName(),
+                      b.getArrayAttr(reassociation));
+
+  // Parse optional attributes.
+  parser.parseOptionalAttrDict(result.attributes);
+
+  // Parse types.
+  Type srcType;
+  Type resultType;
+  if (parser.parseColon() || parser.parseType(srcType) ||
+      parser.resolveOperand(src, srcType, result.operands) ||
+      parser.parseKeyword("into") || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+  return success();
+}
+
+Optional<SmallVector<ReassociationIndices>>
+mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
+                                   ArrayRef<AffineMap> mapsConsumer,
+                                   MLIRContext *context) {
+  // Make the producer the larger-sized vector. If they are the same size, the
+  // resulting reshape is not a supported reshape op.
+  if (mapsProducer.size() == mapsConsumer.size())
+    return llvm::None;
+  if (mapsProducer.size() < mapsConsumer.size())
+    std::swap(mapsProducer, mapsConsumer);
+
+  // Handle the corner case of the result being a rank 0 shaped type. Return an
+  // empty reassociation.
+  if (mapsConsumer.empty())
+    return SmallVector<ReassociationIndices>{};
+  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+    return llvm::None;
+
+  unsigned currDim = 0;
+  SmallVector<ReassociationIndices> reassociationMaps;
+  for (AffineMap rhs : mapsConsumer) {
+    ReassociationIndices reassociations;
+    for (AffineExpr rhsExpr : rhs.getResults()) {
+      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
+      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
+           i < e; ++i)
+        reassociations.push_back(currDim++);
+    }
+    reassociationMaps.push_back(std::move(reassociations));
+  }
+  return reassociationMaps;
+}
+
+bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
+                                int *invalidIndex) {
+  if (reassociation.empty())
+    return true;
+  unsigned nDims = reassociation[0].getNumDims();
+  unsigned nextExpectedDim = 0;
+  for (auto it : llvm::enumerate(reassociation)) {
+    auto m = it.value();
+    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
+      if (invalidIndex)
+        *invalidIndex = it.index();
+      return false;
+    }
+    for (auto e : m.getResults()) {
+      auto d = e.dyn_cast<AffineDimExpr>();
+      if (!d || d.getPosition() != nextExpectedDim++) {
+        if (invalidIndex)
+          *invalidIndex = it.index();
+        return false;
+      }
+    }
+  }
+  if (nextExpectedDim != nDims) {
+    if (invalidIndex)
+      *invalidIndex = reassociation.size() - 1;
+    return false;
+  }
+  return true;
+}

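To make the reassociation computation above concrete, here is a worked example
written as a hypothetical test driver (the function name is illustrative; any
MLIRContext will do):

  #include <cassert>
  #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
  #include "mlir/IR/BuiltinTypes.h"
  using namespace mlir;

  static void checkReassociationExample(MLIRContext *ctx) {
    auto f32 = FloatType::getF32(ctx);
    auto expanded = RankedTensorType::get({2, 3, 4}, f32);
    auto collapsed = RankedTensorType::get({6, 4}, f32);
    // 2 * 3 == 6 and 4 == 4, so source dims are grouped as {0, 1} and {2}.
    auto maps = getReassociationIndicesForReshape(expanded, collapsed);
    assert(maps && maps->size() == 2);
    assert((*maps)[0] == ReassociationIndices({0, 1}));
    assert((*maps)[1] == ReassociationIndices({2}));
    // Same-rank pairs are rejected: reshaping 2x3 "into" 3x2 returns None.
    assert(!getReassociationIndicesForReshape(
        RankedTensorType::get({2, 3}, f32),
        RankedTensorType::get({3, 2}, f32)));
  }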