[Mlir-commits] [mlir] 6412a13 - [mlir] Move common reshapeops-related code to ReshapeOpsUtils.h.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Jul 7 05:56:43 PDT 2021
Author: Alexander Belyaev
Date: 2021-07-07T14:56:16+02:00
New Revision: 6412a13539ab2914eed8e1df83c399b9a16e3408
URL: https://github.com/llvm/llvm-project/commit/6412a13539ab2914eed8e1df83c399b9a16e3408
DIFF: https://github.com/llvm/llvm-project/commit/6412a13539ab2914eed8e1df83c399b9a16e3408.diff
LOG: [mlir] Move common reshapeops-related code to ReshapeOpsUtils.h.
This is a first step to move (Tensor)Expand/CollapseShapeOp to tensor/memref
dialects.
Differential Revision: https://reviews.llvm.org/D105547
Added:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Utils/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 87c279a85e1b..b8a3afd48e40 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -52,16 +53,6 @@ using LoopRangeBuilder =
/// provide an op-specified hook so that Linalg ops may override the behavior.
LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
-using ReassociationIndices = SmallVector<int64_t, 2>;
-using ReassociationIndicesRef = ArrayRef<int64_t>;
-using ReassociationExprs = SmallVector<AffineExpr, 2>;
-
-/// Return the reassociations maps to use to reshape given the source type and
-/// the target type when possible. Return llvm::None when this computation
-/// failed.
-Optional<SmallVector<ReassociationIndices>>
-getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
-
/// Returns the name mangled library call name to disambiguate between different
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
/// type names:
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
new file mode 100644
index 000000000000..0e9a9a4b3894
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -0,0 +1,266 @@
+//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities and common canonicalization patterns for
+// reshape operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+using ReassociationIndices = SmallVector<int64_t, 2>;
+using ReassociationIndicesRef = ArrayRef<int64_t>;
+using ReassociationExprs = SmallVector<AffineExpr, 2>;
+
+/// Attribute name for the ArrayAttr which encodes reassociation indices.
+constexpr StringRef getReassociationAttrName();
+
+/// Collapse reassociation maps that are used in a pair of reshape ops where one
+/// is the producer and the other is the consumer. Only valid to use this method
+/// when both the producer and consumer are collapsing dimensions or both are
+/// expanding dimensions.
+///
+/// For example,
+/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+/// affine_map<(d0, d1, d2) -> (d2)>]
+///
+/// is folded into
+///
+/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+/// TODO: Use reassociation indices instead of affine maps here.
+Optional<SmallVector<ReassociationIndices>>
+collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
+ ArrayRef<AffineMap> mapsConsumer,
+ MLIRContext *context);
+
+/// Return the reassociation indices to use to reshape given the source type and
+/// the target type when possible. Return llvm::None when this computation
+/// fails.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
+
+/// Return true if the reassociation specification is valid, false otherwise.
+/// When false, the `invalidIndex` integer pointer is optionally filled with the
+/// index of the offending reassociation map.
+bool isReassociationValid(ArrayRef<AffineMap> reassociation,
+ int *invalidIndex = nullptr);
+
+/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
+
+/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+template <typename ReshapeLikeOp>
+void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
+ p << op.getOperationName() << ' ' << op.src() << " [";
+
+ llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
+ p << '[';
+ auto arrayAttr = attr.template cast<ArrayAttr>();
+ llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
+ p << attr.cast<IntegerAttr>().getInt();
+ });
+ p << ']';
+ });
+
+ p << "] ";
+ p.printOptionalAttrDict(op->getAttrs(),
+ /*elidedAttrs=*/{op.getReassociationAttrName()});
+ p << ": " << op.src().getType() << " into " << op.getType();
+}
+
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
+ ArrayRef<Attribute> operands) {
+ // Fold producer-consumer reshape ops where the operand type of the
+ // producer is the same as the return type of the consumer.
+ auto reshapeSrcOp =
+ reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+ if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+ return reshapeSrcOp.src();
+ // Reshape of a constant can be replaced with a new constant.
+ if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
+ return elements.reshape(
+ reshapeOp.getResult().getType().template cast<ShapedType>());
+ }
+ return nullptr;
+}
+
+/// Common verifier for reshape-like types. The caller passes the proper `src` or
+/// `result` type as `expandedType` and `collapsedType`.
+template <typename Op, typename T>
+static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
+ T collapsedType, bool isExpansion) {
+ unsigned expandedRank = expandedType.getRank();
+ unsigned collapsedRank = collapsedType.getRank();
+ if (expandedRank < collapsedRank)
+ return op.emitOpError("expected the type ")
+ << expandedType
+ << " to have higher rank than the type = " << collapsedType;
+ if (expandedRank == 0)
+ return op.emitOpError("expected non-zero memref ranks");
+ if (expandedRank == collapsedRank)
+ return op.emitOpError("expected to collapse or expand dims");
+
+ if (collapsedRank == 0) {
+ // If collapsed rank is 0, then expanded type must be static shaped and of
+ // sizes 1.
+ if (llvm::any_of(expandedType.getShape(),
+ [](int64_t dim) -> bool { return dim != 1; }))
+ return op.emitOpError("invalid to reshape tensor/memref with non-unit "
+ "extent dimensions to zero-rank tensor/memref");
+ return success();
+ }
+ if (collapsedRank != op.reassociation().size())
+ return op.emitOpError("expected rank of the collapsed type(")
+ << collapsedRank << ") to be the number of reassociation maps("
+ << op.reassociation().size() << ")";
+ auto maps = op.getReassociationMaps();
+ for (auto it : llvm::enumerate(maps))
+ if (it.value().getNumDims() != expandedRank)
+ return op.emitOpError("expected reassociation map #")
+ << it.index() << " of same rank as expanded memref("
+ << expandedRank << "), but got " << it.value().getNumDims();
+ int invalidIdx = 0;
+ if (!isReassociationValid(maps, &invalidIdx))
+ return op.emitOpError("expected reassociation map #")
+ << invalidIdx << " to be valid and contiguous";
+ return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+}
+
+/// Verify the shapes of the reshaped types using the following rules:
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+/// dimensions in the expanded shape should be
+/// a) static
+/// b) the product should be the same as the collapsed shape.
+/// 2) if a dimension in the collapsed type is dynamic, one and only one of the
+/// corresponding dimensions in the expanded type should be dynamic. This
+/// rule is only needed with reshape operations that are expanding.
+template <typename OpTy>
+static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
+ ShapedType expandedType,
+ bool isExpandingReshape) {
+ ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
+ ArrayRef<int64_t> expandedShape = expandedType.getShape();
+ unsigned expandedDimStart = 0;
+ for (auto map : llvm::enumerate(op.getReassociationMaps())) {
+ Optional<int64_t> dynamicShape;
+ int64_t linearizedStaticShape = 1;
+ for (auto dim : llvm::enumerate(expandedShape.slice(
+ expandedDimStart, map.value().getNumResults()))) {
+ if (ShapedType::isDynamic(dim.value())) {
+ if (isExpandingReshape && dynamicShape) {
+ return op->emitOpError("invalid to have a single dimension (")
+ << map.index() << ") expanded into multiple dynamic dims ("
+ << expandedDimStart + dynamicShape.getValue() << ","
+ << expandedDimStart + dim.index() << ")";
+ }
+ dynamicShape = dim.index();
+ } else {
+ linearizedStaticShape *= dim.value();
+ }
+ }
+ if (dynamicShape) {
+ if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+ return op->emitOpError("expected dimension ")
+ << map.index()
+ << " of collapsed type to be dynamic since one or more of the "
+ "corresponding dimensions in the expanded type is dynamic";
+ }
+ } else {
+ if (collapsedShape[map.index()] != linearizedStaticShape) {
+ return op->emitOpError("expected dimension ")
+ << map.index() << " of collapsed type to be static value of "
+ << linearizedStaticShape << " ";
+ }
+ }
+ expandedDimStart += map.value().getNumResults();
+ }
+ return success();
+}
+
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template <typename ReshapeOpTy>
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+ using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+ PatternRewriter &rewriter) const override {
+ auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
+ if (!srcReshapeOp)
+ return failure();
+
+ ShapedType resultType = reshapeOp.getResultType();
+ Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+ collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
+ reshapeOp.getReassociationMaps(),
+ rewriter.getContext());
+ if (!reassociationIndices)
+ return failure();
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+ return success();
+ }
+};
+
+/// Pattern to fold a reshape op whose producer is the inverse reshape op into
+/// a single reshape op when the reassociation can be inferred from the types.
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+ using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+ PatternRewriter &rewriter) const override {
+ auto srcReshapeOp =
+ reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+ if (!srcReshapeOp)
+ return failure();
+
+ ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+ ShapedType intermediateType = reshapeOp.getSrcType();
+ ShapedType resultType = reshapeOp.getResultType();
+
+ // If the source reshape can be collapsed/expanded into the target reshape
+ // they can still be folded. This can only be reasoned about statically
+ // for cases where
+ // - either all shapes are static, or
+ // - The number of dynamic dimensions matches in the source of source and
+ // result with all other dimensions being 1.
+ Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+ getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+ if (!reassociationIndices)
+ return failure();
+ bool originalOpExpands =
+ intermediateType.getRank() > srcReshapeSrcType.getRank();
+ bool resultingOpExpands =
+ resultType.getRank() > srcReshapeSrcType.getRank();
+ if (!(resultingOpExpands ^ originalOpExpands))
+ rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+ else
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+ return success();
+ }
+};
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index f1fadccd2c2d..ceda829fcb58 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
+ MLIRDialectUtils
MLIRIR
MLIRLinalg
MLIRLinalgUtils
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d4ac19a73ba7..9399a755243c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -1120,8 +1121,7 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
: operandTy.getShape());
unsigned currSrcDim = 0, currDstDim = 0;
- SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
- collapsedShape.size());
+ SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
// First scan all dimensions in the source shapes to see whether we have a
// perfect case where consecutive dimensions in source are collapsed. For
@@ -1176,11 +1176,11 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
std::multiplies<int64_t>());
auto elemTy = operandTy.getElementType();
- SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+ SmallVector<ReassociationExprs, 4> collapsingMap = {
// Use operandTy here because we need to collapse all operands
// dimensions.
getIdentityExprs(operandTy.getShape().size())};
- SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+ SmallVector<ReassociationExprs, 4> expandingMap = {
// Use resultTy here because we need to expand to all result
// dimensions.
getIdentityExprs(resultTy.getShape().size())};
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 62e66a421d5c..66cad6eaa3cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1069,338 +1069,20 @@ OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
// ReshapeOp
//===----------------------------------------------------------------------===//
-Optional<SmallVector<ReassociationIndices>>
-mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
- ShapedType targetType) {
- // Make the sourceType greater rank than the targetType. If they are same
- // rank, then its an unsupported reshape op.
- if (sourceType.getRank() == targetType.getRank())
- return llvm::None;
- if (sourceType.getRank() < targetType.getRank())
- std::swap(sourceType, targetType);
-
- ArrayRef<int64_t> sourceShape = sourceType.getShape();
- ArrayRef<int64_t> targetShape = targetType.getShape();
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetType.getRank());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
-
- // If all the dimensions of the targetShape are exhausted, then the
- // remaining dims in the source shape must be all 1s. So for such cases, set
- // 1 as the target shape. The actual reassociation indices will be handled
- // later.
- int64_t currTargetShape =
- (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
- while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
- sourceDim < sourceShape.size()) {
- prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
- }
-
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
- (currTargetShape != ShapedType::kDynamicSize ||
- prodOfCollapsedDims != 1))
- return llvm::None;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamicSize &&
- sourceShape[sourceDim] != ShapedType::kDynamicSize)
- return llvm::None;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return llvm::None;
-
- currIndices.push_back(sourceDim++);
- // If the reassociation is empty but the currIndices is not, this by
- // definition is folding unit-dimensions with the result being scalar type.
- // So only append the `currIndices` if reassociation map is not empty.
- if (targetDim == targetShape.size()) {
- if (!reassociationMap.empty() && !currIndices.empty())
- reassociationMap.back().append(currIndices.begin(), currIndices.end());
- // Break out of the loops. We should be done here.
- break;
- }
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
- }
- // All the dimensions in the two shapes must have been processed.
- if (reassociationMap.size() != targetShape.size() ||
- sourceDim != sourceShape.size())
- return llvm::None;
- return reassociationMap;
-}
-
-template <typename ReshapeLikeOp>
-static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
- p << op.getOperationName() << ' ' << op.src() << " [";
-
- llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
- p << '[';
- auto arrayAttr = attr.template cast<ArrayAttr>();
- llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
- p << attr.cast<IntegerAttr>().getInt();
- });
- p << ']';
- });
-
- p << "] ";
- p.printOptionalAttrDict(op->getAttrs(),
- /*elidedAttrs=*/{op.getReassociationAttrName()});
- p << ": " << op.src().getType() << " into " << op.getType();
-}
-
static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
- print<linalg::ExpandShapeOp>(p, op);
+ ::mlir::printReshapeOp<linalg::ExpandShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
- print<linalg::CollapseShapeOp>(p, op);
+ ::mlir::printReshapeOp<linalg::CollapseShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
- print<linalg::TensorExpandShapeOp>(p, op);
+ ::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
- print<linalg::TensorCollapseShapeOp>(p, op);
-}
-
-static constexpr StringRef getReassociationAttrName() {
- return "reassociation";
-}
-
-static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
- OperationState &result) {
- // Parse the operand.
- OpAsmParser::OperandType src;
- if (parser.parseOperand(src))
- return failure();
-
- // Parse reassociation indices.
- Builder &b = parser.getBuilder();
- SmallVector<Attribute, 4> reassociation;
- if (parser.parseLSquare())
- return failure();
-
- while (true) {
- if (succeeded(parser.parseOptionalRSquare()))
- break;
- if (parser.parseLSquare())
- return failure();
- SmallVector<int64_t> indices;
- while (true) {
- int64_t index;
- if (parser.parseInteger(index))
- return failure();
- indices.push_back(index);
-
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- break;
- }
- reassociation.push_back(b.getI64ArrayAttr(indices));
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- break;
- }
-
- result.addAttribute(getReassociationAttrName(),
- b.getArrayAttr(reassociation));
-
- // Parse optional attributes.
- parser.parseOptionalAttrDict(result.attributes);
-
- // Parse types.
- Type srcType;
- Type resultType;
- if (parser.parseColon() || parser.parseType(srcType) ||
- parser.resolveOperand(src, srcType, result.operands) ||
- parser.parseKeyword("into") || parser.parseType(resultType))
- return failure();
- result.addTypes(resultType);
- return success();
-}
-
-/// Collapse reassociation maps that are used in pair of reshape ops where one
-/// is a producer and other is the consumer. Only valid to use this method when
-/// both the producer and consumer are collapsing dimensions or both are
-/// expanding dimensions.
-///
-/// For example,
-/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-/// affine_map<(d0, d1, d2) -> (d2)>]
-///
-/// is folded into
-///
-/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-static Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
- ArrayRef<AffineMap> mapsConsumer,
- MLIRContext *context) {
- // Make the producer the larger sized vector. If they are of same size, the
- // resulting reshape is not a supported reshape op.
- if (mapsProducer.size() == mapsConsumer.size())
- return llvm::None;
- if (mapsProducer.size() < mapsConsumer.size())
- std::swap(mapsProducer, mapsConsumer);
-
- // Handle the corner case of the result being a rank 0 shaped type. Return an
- // empty reassociation.
- if (mapsConsumer.empty())
- return SmallVector<ReassociationIndices>{};
- if (mapsProducer.size() != mapsConsumer[0].getNumDims())
- return llvm::None;
-
- unsigned currDim = 0;
- SmallVector<ReassociationIndices> reassociationMaps;
- for (AffineMap rhs : mapsConsumer) {
- ReassociationIndices reassociations;
- for (AffineExpr rhsExpr : rhs.getResults()) {
- AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
- for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
- i < e; ++i)
- reassociations.push_back(currDim++);
- }
- reassociationMaps.push_back(std::move(reassociations));
- }
- return reassociationMaps;
-}
-
-namespace {
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
- using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
- PatternRewriter &rewriter) const override {
- auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
- if (!srcReshapeOp)
- return failure();
-
- ShapedType resultType = reshapeOp.getResultType();
- Optional<SmallVector<ReassociationIndices>> reassociationIndices =
- collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
- reshapeOp.getReassociationMaps(),
- rewriter.getContext());
- if (!reassociationIndices)
- return failure();
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
- return success();
- }
-};
-
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
- using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
- PatternRewriter &rewriter) const override {
- auto srcReshapeOp =
- reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
- if (!srcReshapeOp)
- return failure();
-
- ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
- ShapedType intermediateType = reshapeOp.getSrcType();
- ShapedType resultType = reshapeOp.getResultType();
-
- // If the source reshape can be collapsed/expanded into the target reshape
- // they can still be folded. This can only be reasoned about statically
- // for cases where
- // - either all shapes are static, or
- // - The number of dynamic dimensions matches in the source of source and
- // result with all other dimensions being 1.
- Optional<SmallVector<ReassociationIndices>> reassociationIndices =
- getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
- if (!reassociationIndices)
- return failure();
- bool originalOpExpands =
- intermediateType.getRank() > srcReshapeSrcType.getRank();
- bool resultingOpExpands =
- resultType.getRank() > srcReshapeSrcType.getRank();
- if (!(resultingOpExpands ^ originalOpExpands))
- rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
- else
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
- return success();
- }
-};
-} // namespace
-
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
- ArrayRef<Attribute> operands) {
- // Fold producer-consumer reshape ops that where the operand type of the
- // producer is same as the return type of the consumer.
- auto reshapeSrcOp =
- reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
- if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
- return reshapeSrcOp.src();
- // Reshape of a constant can be replaced with a new constant.
- if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
- return elements.reshape(
- reshapeOp.getResult().getType().template cast<ShapedType>());
- }
- return nullptr;
-}
-
-/// Return true if the reassociation specification is valid, false otherwise.
-/// When false, the `invalidIndex` integer pointer is optionally filled with the
-/// index of the offending reassociation map.
-static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
- int *invalidIndex = nullptr) {
- if (reassociation.empty())
- return true;
- unsigned nDims = reassociation[0].getNumDims();
- unsigned nextExpectedDim = 0;
- for (auto it : llvm::enumerate(reassociation)) {
- auto m = it.value();
- if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
- if (invalidIndex)
- *invalidIndex = it.index();
- return false;
- }
- for (auto e : m.getResults()) {
- auto d = e.dyn_cast<AffineDimExpr>();
- if (!d || d.getPosition() != nextExpectedDim++) {
- if (invalidIndex)
- *invalidIndex = it.index();
- return false;
- }
- }
- }
- if (nextExpectedDim != nDims) {
- if (invalidIndex)
- *invalidIndex = reassociation.size() - 1;
- return false;
- }
- return true;
+ ::mlir::printReshapeOp<linalg::TensorCollapseShapeOp>(p, op);
}
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
@@ -1736,106 +1418,12 @@ void mlir::linalg::CollapseShapeOp::build(
Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-/// dimensions in the expanded shape should be
-/// a) static
-/// b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-/// corresponding dimensions in the expanded type should be dynamic. This
-/// rule is only needed with reshape operations that are expanding.
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
- ShapedType expandedType,
- bool isExpandingReshape) {
- ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
- ArrayRef<int64_t> expandedShape = expandedType.getShape();
- unsigned expandedDimStart = 0;
- for (auto map : llvm::enumerate(op.getReassociationMaps())) {
- Optional<int64_t> dynamicShape;
- int64_t linearizedStaticShape = 1;
- for (auto dim : llvm::enumerate(expandedShape.slice(
- expandedDimStart, map.value().getNumResults()))) {
- if (ShapedType::isDynamic(dim.value())) {
- if (isExpandingReshape && dynamicShape) {
- return op->emitOpError("invalid to have a single dimension (")
- << map.index() << ") expanded into multiple dynamic dims ("
- << expandedDimStart + dynamicShape.getValue() << ","
- << expandedDimStart + dim.index() << ")";
- }
- dynamicShape = dim.index();
- } else {
- linearizedStaticShape *= dim.value();
- }
- }
- if (dynamicShape) {
- if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
- return op->emitOpError("expected dimension ")
- << map.index()
- << " of collapsed type to be dynamic since one or more of the "
- "corresponding dimensions in the expanded type is dynamic";
- }
- } else {
- if (collapsedShape[map.index()] != linearizedStaticShape) {
- return op->emitOpError("expected dimension ")
- << map.index() << " of collapsed type to be static value of "
- << linearizedStaticShape << " ";
- }
- }
- expandedDimStart += map.value().getNumResults();
- }
- return success();
-}
-
-// Common verifier for reshape-like types. Fills `expandedType` and
-// `collapsedType` with the proper `src` or `result` type.
-template <typename Op, typename T,
- bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
- std::is_same<Op, ExpandShapeOp>::value>
-static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
- T collapsedType) {
- unsigned expandedRank = expandedType.getRank();
- unsigned collapsedRank = collapsedType.getRank();
- if (expandedRank < collapsedRank)
- return op.emitOpError("expected the type ")
- << expandedType
- << " to have higher rank than the type = " << collapsedType;
- if (expandedRank == 0)
- return op.emitOpError("expected non-zero memref ranks");
- if (expandedRank == collapsedRank)
- return op.emitOpError("expected to collapse or expand dims");
-
- if (collapsedRank == 0) {
- // If collapsed rank is 0, then expanded type must be static shaped and of
- // sizes 1.
- if (llvm::any_of(expandedType.getShape(),
- [](int64_t dim) -> bool { return dim != 1; }))
- return op.emitOpError("invalid to reshape tensor/memref with non-unit "
- "extent dimensions to zero-rank tensor/memref");
- return success();
- }
- if (collapsedRank != op.reassociation().size())
- return op.emitOpError("expected rank of the collapsed type(")
- << collapsedRank << ") to be the number of reassociation maps("
- << op.reassociation().size() << ")";
- auto maps = op.getReassociationMaps();
- for (auto it : llvm::enumerate(maps))
- if (it.value().getNumDims() != expandedRank)
- return op.emitOpError("expected reassociation map #")
- << it.index() << " of same rank as expanded memref("
- << expandedRank << "), but got " << it.value().getNumDims();
- int invalidIdx = 0;
- if (!isReassociationValid(maps, &invalidIdx))
- return op.emitOpError("expected reassociation map #")
- << invalidIdx << " to be valid and contiguous";
- return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
-}
-
-template <typename TensorReshapeOp>
-static LogicalResult verifyReshapeOp(TensorReshapeOp op,
- MemRefType expandedType,
+template <typename ReshapeOp,
+ bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
MemRefType collapsedType) {
- if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+ if (failed(
+ verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
return failure();
auto maps = op.getReassociationMaps();
MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
@@ -1923,11 +1511,14 @@ void mlir::linalg::TensorExpandShapeOp::build(
getReassociationIndicesAttribute(b, reassociation));
}
-template <typename TensorReshapeOp>
+template <typename TensorReshapeOp,
+ bool isExpansion =
+ std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
RankedTensorType expandedType,
RankedTensorType collapsedType) {
- if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+ if (failed(
+ verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
return failure();
auto maps = op.getReassociationMaps();
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index 098b6b48b032..7e6d2978af41 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRDialectUtils
+ ReshapeOpsUtils.cpp
StructuredOpsUtils.cpp
StaticValueUtils.cpp
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
new file mode 100644
index 000000000000..b353b3195e59
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -0,0 +1,209 @@
+//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+/// Returns the attribute name "reassociation"; parseReshapeLikeOp in this file
+/// uses it as the key under which the parsed reassociation indices are stored.
+// NOTE(review): a constexpr function is implicitly inline, so a definition
+// that lives only in this .cpp may not be visible to other translation units
+// that call it -- confirm against the declaration in ReshapeOpsUtils.h.
+constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
+
+/// Tries to infer the reassociation indices that relate `sourceType` and
+/// `targetType`, i.e. for every dimension of the lower-rank shape the list of
+/// consecutive dimensions of the higher-rank shape that fold into it.
+///
+/// The two types may be passed in either order; the higher-rank one is always
+/// treated as the "source" (expanded) shape. Returns llvm::None when the ranks
+/// are equal or when no grouping of consecutive source dimensions matches the
+/// target shape.
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForReshape(ShapedType sourceType,
+                                        ShapedType targetType) {
+  // Make the sourceType greater rank than the targetType. If they are the same
+  // rank, then it is an unsupported reshape op.
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
+  if (sourceType.getRank() < targetType.getRank())
+    std::swap(sourceType, targetType);
+
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetType.getRank());
+
+  ReassociationIndices currIndices;
+  int64_t prodOfCollapsedDims = 1;
+  while (sourceDim < sourceShape.size()) {
+    unsigned targetDim = reassociationMap.size();
+
+    // If all the dimensions of the targetShape are exhausted, then the
+    // remaining dims in the source shape must be all 1s. So for such cases, set
+    // 1 as the target shape. The actual reassociation indices will be handled
+    // later.
+    int64_t currTargetShape =
+        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+    // Greedily fold static source dims while their product stays below the
+    // target extent. The bound on `sourceDim` must be tested first so that
+    // `sourceShape` is never indexed past its end.
+    while (sourceDim < sourceShape.size() &&
+           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
+    }
+
+    // Every source dim was consumed while the product is still smaller than
+    // the target extent: the shapes cannot match, and there is no source dim
+    // left for the checks below to inspect.
+    if (sourceDim == sourceShape.size())
+      return llvm::None;
+
+    // If the current expanded dimension is dynamic, then the collapsed
+    // dimensions should also be dynamic and product of all previous unprocessed
+    // dimensions of the expanded shape should be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
+        (currTargetShape != ShapedType::kDynamicSize ||
+         prodOfCollapsedDims != 1))
+      return llvm::None;
+
+    // If the collapsed dim is dynamic, the current expanded dim should also
+    // be dynamic.
+    if (currTargetShape == ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != ShapedType::kDynamicSize)
+      return llvm::None;
+
+    // For static shapes, the product of the dimensions of the expanded shape
+    // should match the collapsed dimension shape.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return llvm::None;
+
+    currIndices.push_back(sourceDim++);
+    // If the reassociation is empty but the currIndices is not, this by
+    // definition is folding unit-dimensions with the result being scalar type.
+    // So only append the `currIndices` if reassociation map is not empty.
+    if (targetDim == targetShape.size()) {
+      if (!reassociationMap.empty() && !currIndices.empty())
+        reassociationMap.back().append(currIndices.begin(), currIndices.end());
+      // Break out of the loops. We should be done here.
+      break;
+    }
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
+  }
+  // All the dimensions in the two shapes must have been processed.
+  if (reassociationMap.size() != targetShape.size() ||
+      sourceDim != sourceShape.size())
+    return llvm::None;
+  return reassociationMap;
+}
+
+/// Parses the custom assembly form shared by reshape-like ops:
+///
+///   operand `[` (`[` int (`,` int)* `]` (`,` ...)*)? `]` attr-dict?
+///       `:` src-type `into` result-type
+///
+/// The bracketed list is stored on `result` as an array-of-i64-array attribute
+/// named by getReassociationAttrName(). An empty list `[]` is accepted. The
+/// operand is resolved against the parsed source type and the parsed result
+/// type is added to `result`. Returns failure() as soon as any piece fails to
+/// parse.
+ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Parse the operand.
+  OpAsmParser::OperandType src;
+  if (parser.parseOperand(src))
+    return failure();
+
+  // Parse reassociation indices.
+  Builder &b = parser.getBuilder();
+  SmallVector<Attribute, 4> reassociation;
+  if (parser.parseLSquare())
+    return failure();
+
+  while (true) {
+    // A closing bracket here ends the outer list (this also covers `[]`).
+    if (succeeded(parser.parseOptionalRSquare()))
+      break;
+    if (parser.parseLSquare())
+      return failure();
+    // One inner group: at least one integer, comma separated, closed by `]`.
+    SmallVector<int64_t> indices;
+    while (true) {
+      int64_t index;
+      if (parser.parseInteger(index))
+        return failure();
+      indices.push_back(index);
+
+      if (succeeded(parser.parseOptionalComma()))
+        continue;
+      if (failed(parser.parseRSquare()))
+        return failure();
+      break;
+    }
+    reassociation.push_back(b.getI64ArrayAttr(indices));
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  result.addAttribute(getReassociationAttrName(),
+                      b.getArrayAttr(reassociation));
+
+  // Parse the optional attribute dictionary and the types. The result of
+  // parseOptionalAttrDict must be checked too: a malformed dictionary is a
+  // parse failure and must not be silently skipped.
+  Type srcType;
+  Type resultType;
+  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.parseType(srcType) ||
+      parser.resolveOperand(src, srcType, result.operands) ||
+      parser.parseKeyword("into") || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+  return success();
+}
+
+/// Composes the reassociation maps of two stacked reshapes into a single list
+/// of reassociation indices.
+///
+/// The two lists may be passed in either order; only their sizes decide which
+/// is the longer one. Returns llvm::None when both lists have the same size,
+/// or when the longer list's size differs from the number of dims of the
+/// shorter list's first map. Returns an empty reassociation when the shorter
+/// list is empty. `context` is not used by this implementation.
+Optional<SmallVector<ReassociationIndices>>
+mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
+                                   ArrayRef<AffineMap> mapsConsumer,
+                                   MLIRContext *context) {
+  // Equal sizes do not describe a supported reshape.
+  if (mapsProducer.size() == mapsConsumer.size())
+    return llvm::None;
+
+  // Pick the longer and the shorter list regardless of argument order.
+  const bool producerIsLonger = mapsProducer.size() > mapsConsumer.size();
+  ArrayRef<AffineMap> longerMaps =
+      producerIsLonger ? mapsProducer : mapsConsumer;
+  ArrayRef<AffineMap> shorterMaps =
+      producerIsLonger ? mapsConsumer : mapsProducer;
+
+  // Corner case: the result is a rank 0 shaped type, so the reassociation is
+  // empty.
+  if (shorterMaps.empty())
+    return SmallVector<ReassociationIndices>{};
+  if (longerMaps.size() != shorterMaps[0].getNumDims())
+    return llvm::None;
+
+  SmallVector<ReassociationIndices> composed;
+  unsigned nextDim = 0;
+  for (AffineMap shortMap : shorterMaps) {
+    ReassociationIndices group;
+    for (AffineExpr resultExpr : shortMap.getResults()) {
+      // Each result names one map of the longer list; that map contributes
+      // one consecutive index per result it has.
+      unsigned longIdx = resultExpr.cast<AffineDimExpr>().getPosition();
+      for (int left = longerMaps[longIdx].getNumResults(); left > 0; --left)
+        group.push_back(nextDim++);
+    }
+    composed.push_back(std::move(group));
+  }
+  return composed;
+}
+
+/// Returns true when `reassociation` is a valid list of reassociation maps:
+/// every map has the same number of dims as the first one and no symbols, and
+/// the results of all maps, read in order, are exactly the dim expressions
+/// d0, d1, ..., d(N-1), where N is that common number of dims. An empty list
+/// is valid.
+///
+/// On failure, if `invalidIndex` is non-null it receives the index of the
+/// offending map (the last map when too few dims were covered).
+bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
+                                int *invalidIndex) {
+  if (reassociation.empty())
+    return true;
+
+  // Records the offending map index (when requested) and yields `false`.
+  auto rejectAt = [invalidIndex](size_t mapIdx) {
+    if (invalidIndex)
+      *invalidIndex = mapIdx;
+    return false;
+  };
+
+  const unsigned expectedNumDims = reassociation[0].getNumDims();
+  unsigned nextDim = 0;
+  for (size_t mapIdx = 0, numMaps = reassociation.size(); mapIdx != numMaps;
+       ++mapIdx) {
+    AffineMap map = reassociation[mapIdx];
+    if (map.getNumDims() != expectedNumDims || map.getNumSymbols() != 0)
+      return rejectAt(mapIdx);
+    for (AffineExpr expr : map.getResults()) {
+      // Each result must be a plain dim expression continuing the sequence.
+      auto dimExpr = expr.dyn_cast<AffineDimExpr>();
+      if (!dimExpr || dimExpr.getPosition() != nextDim)
+        return rejectAt(mapIdx);
+      ++nextDim;
+    }
+  }
+  // All dims must have been covered by the maps.
+  if (nextDim != expectedNumDims)
+    return rejectAt(reassociation.size() - 1);
+  return true;
+}
More information about the Mlir-commits
mailing list