[Mlir-commits] [mlir] [mlir] Fix bugs in expand_shape patterns after semantics changes (PR #94631)

Quinn Dawkins llvmlistbot at llvm.org
Thu Jun 6 09:06:58 PDT 2024


================
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  // Otherwise, only 1 dynamic dimension is allowed.
+  if (srcType == resultType &&
+      llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
+  if (reassociations != inverseReassociations)
+    return nullptr;
+  ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+  ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+  if (llvm::none_of(reassociations, [&](auto reInd) {
+        auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
+        auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
----------------
qedawkins wrote:

nit: Spell out the result types here (i.e. use `ArrayRef<int64_t>` for the slices instead of `auto`).

https://github.com/llvm/llvm-project/pull/94631


More information about the Mlir-commits mailing list