[Mlir-commits] [mlir] [MLIR][Tensor, MemRef] Fold expand_shape and collapse_shape if identity (PR #80658)

James Newling llvmlistbot at llvm.org
Mon Feb 5 05:33:02 PST 2024


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/80658

>From 4417550d0e1feadf8ce92bd180754c64fa78482c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 01:48:47 -0800
Subject: [PATCH 1/3] let identity versions pass

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp |  8 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ------------
 mlir/test/Dialect/MemRef/invalid.mlir    | 14 --------------
 mlir/test/Dialect/Tensor/invalid.mlir    | 14 --------------
 4 files changed, 48 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e..dc2ad9f54d0b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2223,10 +2223,6 @@ LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
-
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
@@ -2377,10 +2373,6 @@ LogicalResult CollapseShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
-
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                   srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b2fe58099b2fb..ef8545016d3dc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1656,22 +1656,10 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
 LogicalResult CollapseShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 7bb7a2affcbd1..53841fa836da7 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -415,20 +415,6 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
 func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
   // expected-error @+1 {{op reassociation index 2 is out of bounds}}
   %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 735e5146e9dbc..39cd3788cb081 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -343,20 +343,6 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
 func.func @rank(%0: f32) {
   // expected-error at +1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
   "tensor.rank"(%0): (f32)->index

>From f835843f2388477e345bfe7242d0371d9887d071 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 02:49:15 -0800
Subject: [PATCH 2/3] test

---
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  2 --
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  4 ++++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 ++++
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 19 +++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 21 ++++++++++++++++++-
 5 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f27..3a672a1cc6060 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -111,8 +111,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
            << " to have higher rank than the type = " << collapsedType;
   if (expandedRank == 0)
     return op.emitOpError("expected non-zero memref ranks");
-  if (expandedRank == collapsedRank)
-    return op.emitOpError("expected to collapse or expand dims");
 
   if (collapsedRank == 0) {
     // If collapsed rank is 0, then expanded type must be static shaped and of
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dc2ad9f54d0b1..b79022d675bd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,11 +2447,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ef8545016d3dc..591cbcde80084 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,11 +1860,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034..8c0b765c1b123 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
 
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
+  return %0 : memref<5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
+  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+  return %0 : memref<5x4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ed964071358ac..9b31d858d63c5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
@@ -1977,7 +1996,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
 
 // Chain: NC -> NCnc -> NCnc -> NC
 // CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>, 
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
 // CHECK: return %[[T]] : tensor<128x128xf32>
 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
   %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>

>From 682a594bed444f609f21795ac5478f0b80503c8a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 05:32:48 -0800
Subject: [PATCH 3/3] updates. additional testing, doc corrections, remove
 rank-0 special case handling

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 39 +++++-------
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      | 63 +++++++++----------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  4 --
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 --
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 10 +++
 mlir/test/Dialect/MemRef/invalid.mlir         | 28 +++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 18 ++++++
 8 files changed, 100 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609..39e66cd9e6e5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
-    
+
     ```
     operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
                    ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index eb0c79c01bee1..8134c83022eb3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let summary = "operation to produce a tensor with a higher rank";
   let description = [{
-    The `tensor.expand_shape` op produces a new tensor with a higher
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.expand_shape` op produces a tensor whose rank is higher than
+    (or equal to) that of the operand `src`, and whose dimension sizes are a
+    reassociation of the `src` dimension sizes.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
-
-    The verification rule is that the reassociation maps are applied to the
-    result tensor with the higher rank to obtain the operand tensor with the
-    smaller rank.
+    A reassociation is defined as a contiguous grouping of dimensions. It is
+    represented by an array of DenseI64ArrayAttr attributes. Entries in the
+    array are referred to as reassociation maps.
 
-    The operand tensor type of a reshape can be zero-ranked if the result
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such cases the reassociation map is empty.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
-    Examples:
+    Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
   let description = [{
-    The `tensor.collapse_shape` op produces a new tensor with a smaller
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+    rank, whose dimension sizes are a reassociation of the `src` dimensions.
 
     A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
+    represented by an array of DenseI64ArrayAttr attributes. The reassociation
+    maps are applied to the operand shape to obtain the result shape.
 
-    The verification rule is that the reassociation maps are applied to the
-    operand tensor with the higher rank to obtain the result tensor with the
-    smaller rank.
 
-    The result tensor type of a reshape can be zero-ranked if the operand
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    Examples:
+    Example:
 
     ```mlir
     // Dimension collapse (i, j) -> i' and k -> k'
@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     and optionally transposes the tiled source tensor dimensions.
 
     `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
-    being tiled, where `0 < k <= n`. The order of the dimensions matters: 
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
      - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
     tensor in the order in which they appear in `inner_dims_pos`.
      - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a672a1cc6060..7192b5846cb3d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-  // Fold producer-consumer reshape ops that where the operand type of the
+
+  if (reshapeOp.getSrcType() == reshapeOp.getType())
+    return reshapeOp.getSrc();
+
+  // Fold producer-consumer reshape ops where the operand type of the
   // producer is same as the return type of the consumer.
   auto reshapeSrcOp =
       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.getSrc();
+
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
-  }
+
   return nullptr;
 }
 
@@ -103,39 +108,37 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
 template <typename Op, typename T>
 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                             T collapsedType, bool isExpansion) {
+
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   if (expandedRank < collapsedRank)
-    return op.emitOpError("expected the type ")
-           << expandedType
-           << " to have higher rank than the type = " << collapsedType;
-  if (expandedRank == 0)
-    return op.emitOpError("expected non-zero memref ranks");
-
-  if (collapsedRank == 0) {
-    // If collapsed rank is 0, then expanded type must be static shaped and of
-    // sizes 1.
-    if (llvm::any_of(expandedType.getShape(),
-                     [](int64_t dim) -> bool { return dim != 1; }))
-      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
-                            "extent dimensions to zero-rank tensor/memref");
-    return success();
-  }
+    return op.emitOpError("expected the expanded type, ")
+           << expandedType << " to have a higher (or same) rank "
+           << "than the collapsed type, " << collapsedType << '.';
+
   if (collapsedRank != op.getReassociation().size())
-    return op.emitOpError("expected rank of the collapsed type(")
-           << collapsedRank << ") to be the number of reassociation maps("
-           << op.getReassociation().size() << ")";
+    return op.emitOpError("expected collapsed rank (")
+           << collapsedRank << ") to equal the number of reassociation maps ("
+           << op.getReassociation().size() << ").";
+
   auto maps = op.getReassociationMaps();
   for (auto it : llvm::enumerate(maps))
     if (it.value().getNumDims() != expandedRank)
       return op.emitOpError("expected reassociation map #")
-             << it.index() << " of same rank as expanded memref("
-             << expandedRank << "), but got " << it.value().getNumDims();
+             << it.index() << " to have size equal to the expanded rank ("
+             << expandedRank << "), but it is  " << it.value().getNumDims()
+             << '.';
+
   int invalidIdx = 0;
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
-           << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+           << invalidIdx << " to be valid and contiguous.";
+
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpansion);
+
 }
 
 /// Verify that shapes of the reshaped types using following rules
@@ -151,16 +154,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
 
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
-                                             ShapedType expandedType,
-                                             bool isExpandingReshape) {
-  return reshapeLikeShapesAreCompatible(
-      [&](const Twine &msg) { return op->emitOpError(msg); },
-      collapsedType.getShape(), expandedType.getShape(),
-      op.getReassociationIndices(), isExpandingReshape);
-}
-
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79022d675bd6..dc2ad9f54d0b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,15 +2447,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 591cbcde80084..ef8545016d3dc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,15 +1860,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 8c0b765c1b123..7a70e4737fb80 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -19,6 +19,16 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
 
 // -----
 
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+  %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+  return %1 : memref<1x1xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 53841fa836da7..6914d72c1ae30 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -448,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[0], [0]] :
+    memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+    memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[1], [0]] :
+    memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_reshaping_non_contiguous(
     %arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
   // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9b31d858d63c5..c227018fd6f08 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -10,6 +10,15 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
 
 // -----
 
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: collapse_shape_identity_fold
 // CHECK-NEXT: return
 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
@@ -19,6 +28,15 @@ func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf
 
 // -----
 
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {



More information about the Mlir-commits mailing list