[Mlir-commits] [mlir] [MLIR][Tensor, MemRef] Fold expand_shape and collapse_shape if identity (PR #80658)

James Newling llvmlistbot at llvm.org
Thu Feb 29 11:52:04 PST 2024


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/80658

>From 847a899ab85d56b589a078a74b267cf2158cb1cf Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 01:48:47 -0800
Subject: [PATCH 1/5] let identity versions pass

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp |  8 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ------------
 mlir/test/Dialect/MemRef/invalid.mlir    | 14 --------------
 mlir/test/Dialect/Tensor/invalid.mlir    | 14 --------------
 4 files changed, 48 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..79c91a5f8c6905 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2224,10 +2224,6 @@ LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
-
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
@@ -2378,10 +2374,6 @@ LogicalResult CollapseShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
-
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                   srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..4dec5d7280ce55 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1656,22 +1656,10 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
 LogicalResult CollapseShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 8f5ba5ea8fc78d..b7fab09f6f97df 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -415,20 +415,6 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
 func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
   // expected-error @+1 {{op reassociation index 2 is out of bounds}}
   %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 4c534fe936e3de..79ca0de68a1e9b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -343,20 +343,6 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
 func.func @rank(%0: f32) {
   // expected-error at +1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
   "tensor.rank"(%0): (f32)->index

>From 7906997ac7195ef6498b82aa817e49bcfdde2f85 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 02:49:15 -0800
Subject: [PATCH 2/5] test

---
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  2 --
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  4 ++++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 ++++
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 19 +++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 21 ++++++++++++++++++-
 5 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f272..3a672a1cc60601 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -111,8 +111,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
            << " to have higher rank than the type = " << collapsedType;
   if (expandedRank == 0)
     return op.emitOpError("expected non-zero memref ranks");
-  if (expandedRank == collapsedRank)
-    return op.emitOpError("expected to collapse or expand dims");
 
   if (collapsedRank == 0) {
     // If collapsed rank is 0, then expanded type must be static shaped and of
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 79c91a5f8c6905..0fc4148057cb3e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,11 +2448,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4dec5d7280ce55..11d7f95f5943e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,11 +1860,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  if (getSrcType() == getType())
+    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..cd057d0478aaae 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
 
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
+  return %0 : memref<5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
+  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+  return %0 : memref<5x4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..1fee7ee263b9da 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
@@ -2069,7 +2088,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
 
 // Chain: NC -> NCnc -> NCnc -> NC
 // CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>, 
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
 // CHECK: return %[[T]] : tensor<128x128xf32>
 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
   %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>

>From 75ddea2c695000b2ec5d1ca7b54da5e83e93d998 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 05:32:48 -0800
Subject: [PATCH 3/5] updates. additional testing, doc corrections, remove
 rank-0 special case handling

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 39 +++++-------
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      | 63 +++++++++----------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  4 --
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 --
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 10 +++
 mlir/test/Dialect/MemRef/invalid.mlir         | 28 +++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 18 ++++++
 8 files changed, 100 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..39e66cd9e6e5ab 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
-    
+
     ```
     operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
                    ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 1c61ece2676a90..670202fe4372e6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let summary = "operation to produce a tensor with a higher rank";
   let description = [{
-    The `tensor.expand_shape` op produces a new tensor with a higher
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.expand_shape` op produces a tensor of higher (or equal)
+    rank than the operand `src`; the result dimension sizes are a
+    reassociation of the dimension sizes of `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
-
-    The verification rule is that the reassociation maps are applied to the
-    result tensor with the higher rank to obtain the operand tensor with the
-    smaller rank.
+    A reassociation is defined as a contiguous grouping of dimensions. It is
+    represented as an array of DenseI64ArrayAttr attributes. Entries in the
+    array are referred to as reassociation maps.
 
-    The operand tensor type of a reshape can be zero-ranked if the result
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such cases the reassociation map is empty.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
-    Examples:
+    Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
   let description = [{
-    The `tensor.collapse_shape` op produces a new tensor with a smaller
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+    rank whose dimension sizes are a reassociation of the original `src` dimensions.
 
     A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
+    represented by an array of DenseI64ArrayAttr attributes. The reassociation
+    maps are applied to the operand shape to obtain the result shape.
 
-    The verification rule is that the reassociation maps are applied to the
-    operand tensor with the higher rank to obtain the result tensor with the
-    smaller rank.
 
-    The result tensor type of a reshape can be zero-ranked if the operand
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    Examples:
+    Example:
 
     ```mlir
     // Dimension collapse (i, j) -> i' and k -> k'
@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     and optionally transposes the tiled source tensor dimensions.
 
     `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
-    being tiled, where `0 < k <= n`. The order of the dimensions matters: 
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
      - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
     tensor in the order in which they appear in `inner_dims_pos`.
      - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a672a1cc60601..7192b5846cb3d9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-  // Fold producer-consumer reshape ops that where the operand type of the
+
+  if (reshapeOp.getSrcType() == reshapeOp.getType())
+    return reshapeOp.getSrc();
+
+  // Fold producer-consumer reshape ops where the operand type of the
   // producer is same as the return type of the consumer.
   auto reshapeSrcOp =
       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.getSrc();
+
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
-  }
+
   return nullptr;
 }
 
@@ -103,39 +108,37 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
 template <typename Op, typename T>
 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                             T collapsedType, bool isExpansion) {
+
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   if (expandedRank < collapsedRank)
-    return op.emitOpError("expected the type ")
-           << expandedType
-           << " to have higher rank than the type = " << collapsedType;
-  if (expandedRank == 0)
-    return op.emitOpError("expected non-zero memref ranks");
-
-  if (collapsedRank == 0) {
-    // If collapsed rank is 0, then expanded type must be static shaped and of
-    // sizes 1.
-    if (llvm::any_of(expandedType.getShape(),
-                     [](int64_t dim) -> bool { return dim != 1; }))
-      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
-                            "extent dimensions to zero-rank tensor/memref");
-    return success();
-  }
+    return op.emitOpError("expected the expanded type, ")
+           << expandedType << " to have a higher (or same) rank "
+           << "than the collapsed type, " << collapsedType << '.';
+
   if (collapsedRank != op.getReassociation().size())
-    return op.emitOpError("expected rank of the collapsed type(")
-           << collapsedRank << ") to be the number of reassociation maps("
-           << op.getReassociation().size() << ")";
+    return op.emitOpError("expected collapsed rank (")
+           << collapsedRank << ") to equal the number of reassociation maps ("
+           << op.getReassociation().size() << ").";
+
   auto maps = op.getReassociationMaps();
   for (auto it : llvm::enumerate(maps))
     if (it.value().getNumDims() != expandedRank)
       return op.emitOpError("expected reassociation map #")
-             << it.index() << " of same rank as expanded memref("
-             << expandedRank << "), but got " << it.value().getNumDims();
+             << it.index() << " to have size equal to the expanded rank ("
+             << expandedRank << "), but it is  " << it.value().getNumDims()
+             << '.';
+
   int invalidIdx = 0;
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
-           << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+           << invalidIdx << " to be valid and contiguous.";
+
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpansion);
+
 }
 
 /// Verify that shapes of the reshaped types using following rules
@@ -151,16 +154,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
 
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
-                                             ShapedType expandedType,
-                                             bool isExpandingReshape) {
-  return reshapeLikeShapesAreCompatible(
-      [&](const Twine &msg) { return op->emitOpError(msg); },
-      collapsedType.getShape(), expandedType.getShape(),
-      op.getReassociationIndices(), isExpandingReshape);
-}
-
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0fc4148057cb3e..79c91a5f8c6905 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,15 +2448,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 11d7f95f5943e2..4dec5d7280ce55 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,15 +1860,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index cd057d0478aaae..b1e92e54d561da 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -19,6 +19,16 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
 
 // -----
 
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+  %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+  return %1 : memref<1x1xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index b7fab09f6f97df..64c91cf40d0a1e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -448,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[0], [0]] :
+    memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+    memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[1], [0]] :
+    memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_reshaping_non_contiguous(
     %arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
   // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1fee7ee263b9da..6086a740aef671 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -10,6 +10,15 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
 
 // -----
 
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: collapse_shape_identity_fold
 // CHECK-NEXT: return
 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
@@ -19,6 +28,15 @@ func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf
 
 // -----
 
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {

>From 3684e452fe06bfd34b6694aa8c333fcfc582e83f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 06:04:44 -0800
Subject: [PATCH 4/5] clang-format

---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 7192b5846cb3d9..ae9824f728da4d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -138,7 +138,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       [&](const Twine &msg) { return op->emitOpError(msg); },
       collapsedType.getShape(), expandedType.getShape(),
       op.getReassociationIndices(), isExpansion);
-
 }
 
 /// Verify that shapes of the reshaped types using following rules

>From e9088b01c5043448093d10e040842abb5a0fe184 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Sat, 24 Feb 2024 13:47:53 -0800
Subject: [PATCH 5/5] improved error message when rank change is not correct

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 16 ++++++++++++++++
 mlir/test/Dialect/MemRef/invalid.mlir    |  4 ++--
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 79c91a5f8c6905..94e0ed319cae83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2224,6 +2224,14 @@ LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  if (srcType.getRank() > resultType.getRank()) {
+    auto r0 = srcType.getRank();
+    auto r1 = resultType.getRank();
+    return emitOpError("has source rank ")
+           << r0 << " and result rank " << r1 << ". This is not an expansion ("
+           << r0 << " > " << r1 << ").";
+  }
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
@@ -2374,6 +2382,14 @@ LogicalResult CollapseShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  if (srcType.getRank() < resultType.getRank()) {
+    auto r0 = srcType.getRank();
+    auto r1 = resultType.getRank();
+    return emitOpError("has source rank ")
+           << r0 << " and result rank " << r1 << ". This is not a collapse ("
+           << r0 << " < " << r1 << ").";
+  }
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                   srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 64c91cf40d0a1e..1aef417549d9a1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -451,7 +451,7 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
 // An (invalid) attempt at using collapse_shape to increase the rank might look
 // like this. Verify that a sensible error is emitted in this case.
 func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{reassociation indices must be contiguous}}
+  // expected-error @+1 {{'memref.collapse_shape' op has source rank 1 and result rank 2. This is not a collapse (1 < 2)}}
   %0 = memref.collapse_shape %arg0 [[0], [0]] :
     memref<?xf32> into memref<?x?xf32>
 }
@@ -461,7 +461,7 @@ func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>)
 // An (invalid) attempt at using expand_shape to reduce the rank might look
 // like this. Verify that a sensible error is emitted in this case.
 func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
-  // expected-error @+1 {{reassociation indices must be contiguous}}
+  // expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}}
   %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
     memref<2x3x1xf32> into memref<2x3xf32>
 }



More information about the Mlir-commits mailing list