[Mlir-commits] [mlir] d659527 - [mlir] Use indices instead of affine maps when composing 2 reshape ops.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Jul 7 06:22:09 PDT 2021
Author: Alexander Belyaev
Date: 2021-07-07T15:21:46+02:00
New Revision: d6595278291425804d05985652831d2781abdf06
URL: https://github.com/llvm/llvm-project/commit/d6595278291425804d05985652831d2781abdf06
DIFF: https://github.com/llvm/llvm-project/commit/d6595278291425804d05985652831d2781abdf06.diff
LOG: [mlir] Use indices instead of affine maps when composing 2 reshape ops.
https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310
Differential Revision: https://reviews.llvm.org/D105550
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 0e9a9a4b3894..3a557a63fd26 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -28,27 +28,22 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName();
-/// Collapse reassociation maps that are used in pair of reshape ops where one
+/// Compose reassociation indices that are used in a pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
/// expanding dimensions.
///
/// For example,
-/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-/// affine_map<(d0, d1, d2) -> (d2)>]
+/// producerReassociations = [[0, 1], [2], [3, 4]]
+/// consumerReassociations = [[0, 1], [2]]
///
/// is folded into
///
-/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-/// TODO: Use reassociation indices instead of affine maps here.
-Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
- ArrayRef<AffineMap> mapsConsumer,
- MLIRContext *context);
+/// result = [[0, 1, 2], [3, 4]].
+Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
+ ArrayRef<ReassociationIndices> producerReassociations,
+ ArrayRef<ReassociationIndices> consumerReassociations,
+ MLIRContext *context);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return llvm::None when this computation
@@ -210,8 +205,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
ShapedType resultType = reshapeOp.getResultType();
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
- collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
- reshapeOp.getReassociationMaps(),
+ composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
+ reshapeOp.getReassociationIndices(),
rewriter.getContext());
if (!reassociationIndices)
return failure();
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b353b3195e59..4cd72e2c9ff3 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -11,6 +11,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include <numeric>
+
using namespace mlir;
constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
@@ -145,37 +147,45 @@ ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
return success();
}
-Optional<SmallVector<ReassociationIndices>>
-mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
- ArrayRef<AffineMap> mapsConsumer,
- MLIRContext *context) {
+Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
+ ArrayRef<ReassociationIndices> producerReassociations,
+ ArrayRef<ReassociationIndices> consumerReassociations,
+ MLIRContext *context) {
+ SmallVector<ReassociationIndices> composedIndices;
// Make the producer the larger sized vector. If they are of same size, the
// resulting reshape is not a supported reshape op.
- if (mapsProducer.size() == mapsConsumer.size())
+ if (producerReassociations.size() == consumerReassociations.size())
return llvm::None;
- if (mapsProducer.size() < mapsConsumer.size())
- std::swap(mapsProducer, mapsConsumer);
+ if (producerReassociations.size() < consumerReassociations.size())
+ std::swap(producerReassociations, consumerReassociations);
// Handle the corner case of the result being a rank 0 shaped type. Return an
// empty reassociation.
- if (mapsConsumer.empty())
- return SmallVector<ReassociationIndices>{};
- if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+ if (consumerReassociations.empty())
+ return composedIndices;
+
+ // Each consumer index selects one producer group (see the lookup below), so
+ // the consumer groups must hold exactly as many indices as there are
+ // producer groups. Seed with size_t(0): a plain 0 would make the
+ // std::accumulate accumulator an int and narrow every partial sum.
+ size_t consumerDims = std::accumulate(
+ consumerReassociations.begin(), consumerReassociations.end(), size_t(0),
+ [](size_t all, ReassociationIndicesRef indices) {
+ return all + indices.size();
+ });
+ if (producerReassociations.size() != consumerDims)
return llvm::None;
- unsigned currDim = 0;
- SmallVector<ReassociationIndices> reassociationMaps;
- for (AffineMap rhs : mapsConsumer) {
+ for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
ReassociationIndices reassociations;
- for (AffineExpr rhsExpr : rhs.getResults()) {
- AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
- for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
- i < e; ++i)
- reassociations.push_back(currDim++);
+ // Concatenate the producer groups selected by this consumer group.
+ for (int64_t consumerIndex : consumerIndices) {
+ for (int64_t producerIndex : producerReassociations[consumerIndex])
+ reassociations.push_back(producerIndex);
}
- reassociationMaps.push_back(std::move(reassociations));
+ composedIndices.push_back(std::move(reassociations));
}
- return reassociationMaps;
+ return composedIndices;
}
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
More information about the Mlir-commits mailing list