[Mlir-commits] [mlir] d659527 - [mlir] Use indices instead of affine maps when composing 2 reshape ops.

Alexander Belyaev llvmlistbot at llvm.org
Wed Jul 7 06:22:09 PDT 2021


Author: Alexander Belyaev
Date: 2021-07-07T15:21:46+02:00
New Revision: d6595278291425804d05985652831d2781abdf06

URL: https://github.com/llvm/llvm-project/commit/d6595278291425804d05985652831d2781abdf06
DIFF: https://github.com/llvm/llvm-project/commit/d6595278291425804d05985652831d2781abdf06.diff

LOG: [mlir] Use indices instead of affine maps when composing 2 reshape ops.

https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310

Differential Revision: https://reviews.llvm.org/D105550

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 0e9a9a4b3894..3a557a63fd26 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -28,27 +28,29 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName();
 
-/// Collapse reassociation maps that are used in pair of reshape ops where one
+/// Compose reassociation indices that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
 /// expanding dimensions.
 ///
 /// For example,
-///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2) -> (d2)>]
+///   producerReassociations = [[0, 1], [2], [3, 4]]
+///   consumerReassociations = [[0, 1], [2]]
 ///
 /// is folded into
 ///
-///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-/// TODO: Use reassociation indices instead of affine maps here.
-Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                             ArrayRef<AffineMap> mapsConsumer,
-                             MLIRContext *context);
+///   result = [[0, 1, 2], [3, 4]].
+///
+/// The argument with more groups is treated as the producer, whichever order
+/// the two are passed in. Returns llvm::None if they cannot be composed, e.g.
+/// when both have the same number of groups, or when the consumer's total
+/// index count differs from the producer's group count. Returns an empty
+/// result if the smaller reassociation is empty (rank-0 result). `context` is
+/// currently unused.
+Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
+    ArrayRef<ReassociationIndices> producerReassociations,
+    ArrayRef<ReassociationIndices> consumerReassociations,
+    MLIRContext *context);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return llvm::None when this computation
@@ -210,8 +212,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 
     ShapedType resultType = reshapeOp.getResultType();
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
-                                     reshapeOp.getReassociationMaps(),
+        composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
+                                     reshapeOp.getReassociationIndices(),
                                      rewriter.getContext());
     if (!reassociationIndices)
       return failure();

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b353b3195e59..4cd72e2c9ff3 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -11,6 +11,8 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 
+#include <numeric>
+
 using namespace mlir;
 
 constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
@@ -145,37 +147,48 @@ ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
   return success();
 }
 
-Optional<SmallVector<ReassociationIndices>>
-mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                                   ArrayRef<AffineMap> mapsConsumer,
-                                   MLIRContext *context) {
+Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
+    ArrayRef<ReassociationIndices> producerReassociations,
+    ArrayRef<ReassociationIndices> consumerReassociations,
+    MLIRContext *context) {
+  SmallVector<ReassociationIndices> composedIndices;
   // Make the producer the larger sized vector. If they are of same size, the
   // resulting reshape is not a supported reshape op.
-  if (mapsProducer.size() == mapsConsumer.size())
+  if (producerReassociations.size() == consumerReassociations.size())
     return llvm::None;
-  if (mapsProducer.size() < mapsConsumer.size())
-    std::swap(mapsProducer, mapsConsumer);
+  if (producerReassociations.size() < consumerReassociations.size())
+    std::swap(producerReassociations, consumerReassociations);
 
   // Handle the corner case of the result being a rank 0 shaped type. Return an
   // empty reassociation.
-  if (mapsConsumer.empty())
-    return SmallVector<ReassociationIndices>{};
-  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+  if (consumerReassociations.empty())
+    return composedIndices;
+
+  // Each consumer index names one producer group, so the consumer must use
+  // exactly as many indices as the producer has groups. Accumulate in size_t
+  // (not int, which a literal `0` would deduce) to match `size()`.
+  size_t consumerDims = std::accumulate(
+      consumerReassociations.begin(), consumerReassociations.end(), size_t{0},
+      [](size_t all, ReassociationIndicesRef indices) {
+        return all + indices.size();
+      });
+  if (producerReassociations.size() != consumerDims)
     return llvm::None;
 
-  unsigned currDim = 0;
-  SmallVector<ReassociationIndices> reassociationMaps;
-  for (AffineMap rhs : mapsConsumer) {
+  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
     ReassociationIndices reassociations;
-    for (AffineExpr rhsExpr : rhs.getResults()) {
-      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
-      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
-           i < e; ++i)
-        reassociations.push_back(currDim++);
+    for (int64_t consumerIndex : consumerIndices) {
+      // The count check above does not guarantee each index is in range.
+      if (consumerIndex < 0 ||
+          static_cast<size_t>(consumerIndex) >= producerReassociations.size())
+        return llvm::None;
+      // Concatenate the producer group this consumer index refers to.
+      for (int64_t producerIndex : producerReassociations[consumerIndex])
+        reassociations.push_back(producerIndex);
     }
-    reassociationMaps.push_back(std::move(reassociations));
+    composedIndices.push_back(std::move(reassociations));
   }
-  return reassociationMaps;
+  return composedIndices;
 }
 
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,


        


More information about the Mlir-commits mailing list