[Mlir-commits] [mlir] [mlir] Fix bugs in expand_shape patterns after semantics changes (PR #94631)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Jun 6 09:06:57 PDT 2024
================
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
-
+ // Fold identity reshape.
if (reshapeOp.getSrcType() == reshapeOp.getType())
return reshapeOp.getSrc();
- // Fold producer-consumer reshape ops where the operand type of the
- // producer is same as the return type of the consumer.
- auto reshapeSrcOp =
- reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
- if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
- return reshapeSrcOp.getSrc();
-
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ // Fold if the producer reshape source has the same shape with at most 1
+ // dynamic dimension.
+ auto reshapeSrcOp =
+ reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+ if (!reshapeSrcOp)
+ return nullptr;
+ auto srcType = reshapeSrcOp.getSrcType();
+ auto resultType = reshapeOp.getResultType();
+ if (srcType != resultType)
+ return nullptr;
+
+ // If the reshapes are expanding and then collapsing, the ops can be folded
+ // despite multiple dynamic dimensions.
+ if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+ return reshapeSrcOp.getSrc();
----------------
qedawkins wrote:
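This is also only valid when the reassociation indices are the same. In the example below, the expand splits source dim 1 into dims [1, 2], while the collapse merges dims [0, 1]. The result therefore has shape (%arg1 * %arg2, %arg3), whereas %arg0 has shape (%arg1, %arg2 * %arg3). Folding to %arg0 would be wrong, even though both types are tensor<?x?xf32>.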
```
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%arg1, %arg2, %arg3]
: tensor<?x?xf32> into tensor<?x?x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
https://github.com/llvm/llvm-project/pull/94631
More information about the Mlir-commits
mailing list