[Mlir-commits] [mlir] [MLIR][Tensor, MemRef] Fold expand_shape and collapse_shape if identity (PR #80658)
James Newling
llvmlistbot at llvm.org
Mon Feb 5 05:33:02 PST 2024
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/80658
>From 4417550d0e1feadf8ce92bd180754c64fa78482c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 01:48:47 -0800
Subject: [PATCH 1/3] let identity versions pass
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 --------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ------------
mlir/test/Dialect/MemRef/invalid.mlir | 14 --------------
mlir/test/Dialect/Tensor/invalid.mlir | 14 --------------
4 files changed, 48 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e..dc2ad9f54d0b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2223,10 +2223,6 @@ LogicalResult ExpandShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
- if (srcType.getRank() >= resultType.getRank())
- return emitOpError("expected rank expansion, but found source rank ")
- << srcType.getRank() << " >= result rank " << resultType.getRank();
-
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
resultType.getShape(),
@@ -2377,10 +2373,6 @@ LogicalResult CollapseShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
- if (srcType.getRank() <= resultType.getRank())
- return emitOpError("expected rank reduction, but found source rank ")
- << srcType.getRank() << " <= result rank " << resultType.getRank();
-
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b2fe58099b2fb..ef8545016d3dc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1656,22 +1656,10 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
- if (srcType.getRank() >= resultType.getRank())
- return emitOpError("expected rank expansion, but found source rank ")
- << srcType.getRank() << " >= result rank " << resultType.getRank();
-
return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
LogicalResult CollapseShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
- if (srcType.getRank() <= resultType.getRank())
- return emitOpError("expected rank reduction, but found source rank ")
- << srcType.getRank() << " <= result rank " << resultType.getRank();
-
return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 7bb7a2affcbd1..53841fa836da7 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -415,20 +415,6 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
// -----
-func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
- // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
- %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
- // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
- %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
// expected-error @+1 {{op reassociation index 2 is out of bounds}}
%0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 735e5146e9dbc..39cd3788cb081 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -343,20 +343,6 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
// -----
-func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
- // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
- %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
- // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
- %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
func.func @rank(%0: f32) {
// expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
"tensor.rank"(%0): (f32)->index
>From f835843f2388477e345bfe7242d0371d9887d071 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 02:49:15 -0800
Subject: [PATCH 2/3] test
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 2 --
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++++
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 ++++++++++++++++++-
5 files changed, 47 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f27..3a672a1cc6060 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -111,8 +111,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
<< " to have higher rank than the type = " << collapsedType;
if (expandedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
- if (expandedRank == collapsedRank)
- return op.emitOpError("expected to collapse or expand dims");
if (collapsedRank == 0) {
// If collapsed rank is 0, then expanded type must be static shaped and of
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dc2ad9f54d0b1..b79022d675bd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,11 +2447,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ef8545016d3dc..591cbcde80084 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,11 +1860,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034..8c0b765c1b123 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
+ %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
+ return %0 : memref<5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
+ %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+ return %0 : memref<5x4xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ed964071358ac..9b31d858d63c5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
+ %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
+ return %0 : tensor<5x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
@@ -1977,7 +1996,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
// Chain: NC -> NCnc -> NCnc -> NC
// CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
// CHECK: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
%tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
>From 682a594bed444f609f21795ac5478f0b80503c8a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 05:32:48 -0800
Subject: [PATCH 3/3] updates. additional testing, doc corrections, remove
rank-0 special case handling
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 2 +-
.../mlir/Dialect/Tensor/IR/TensorOps.td | 39 +++++-------
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 63 +++++++++----------
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 --
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 --
mlir/test/Dialect/MemRef/canonicalize.mlir | 10 +++
mlir/test/Dialect/MemRef/invalid.mlir | 28 +++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 18 ++++++
8 files changed, 100 insertions(+), 68 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609..39e66cd9e6e5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
let summary = "non-blocking DMA operation that starts a transfer";
let description = [{
Syntax:
-
+
```
operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index eb0c79c01bee1..8134c83022eb3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let summary = "operation to produce a tensor with a higher rank";
let description = [{
- The `tensor.expand_shape` op produces a new tensor with a higher
- rank whose sizes are a reassociation of the original `src`.
+ The `tensor.expand_shape` op produces a tensor of higher (or equal)
+ rank than the operand `src` whose dimension sizes are a reassociation of
+ `src`.
- A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of DenseI64ArrayAttr attribute.
-
- The verification rule is that the reassociation maps are applied to the
- result tensor with the higher rank to obtain the operand tensor with the
- smaller rank.
+ A reassociation is defined as a continuous grouping of dimensions. It is
+ represented with an array of DenseI64ArrayAttr attribute. Entries in the
+ array are referred to as reassociation maps.
- The operand tensor type of a reshape can be zero-ranked if the result
- tensor type is statically shaped with all dimensions being unit extent. In
- such cases the reassociation map is empty.
+ The reassociation maps are applied to the result shape to obtain the operand
+ shape.
- Examples:
+ Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
let description = [{
- The `tensor.collapse_shape` op produces a new tensor with a smaller
- rank whose sizes are a reassociation of the original `src`.
+ The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+ rank whose dimension sizes are a reassociation of the original `src` dimensions.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of DenseI64ArrayAttr attribute.
+ represented by an array of DenseI64ArrayAttr attribute. The reassociation
+ maps are applied to the operand shape to obtain the result shape.
- The verification rule is that the reassociation maps are applied to the
- operand tensor with the higher rank to obtain the result tensor with the
- smaller rank.
- The result tensor type of a reshape can be zero-ranked if the operand
- tensor type is statically shaped with all dimensions being unit extent. In
- such case the reassociation map is empty.
-
- Examples:
+ Example:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
and optionally transposes the tiled source tensor dimensions.
`inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
- being tiled, where `0 < k <= n`. The order of the dimensions matters:
+ being tiled, where `0 < k <= n`. The order of the dimensions matters:
- The tiled dimensions (of size `inner_tiles`) are added to the end of the result
tensor in the order in which they appear in `inner_dims_pos`.
- `inner_dims_pos[i]` specifies the source tensor dimension tiled by
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a672a1cc6060..7192b5846cb3d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
- // Fold producer-consumer reshape ops that where the operand type of the
+
+ if (reshapeOp.getSrcType() == reshapeOp.getType())
+ return reshapeOp.getSrc();
+
+ // Fold producer-consumer reshape ops where the operand type of the
// producer is same as the return type of the consumer.
auto reshapeSrcOp =
reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.getSrc();
+
// Reshape of a constant can be replaced with a new constant.
- if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
- }
+
return nullptr;
}
@@ -103,39 +108,37 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
T collapsedType, bool isExpansion) {
+
unsigned expandedRank = expandedType.getRank();
unsigned collapsedRank = collapsedType.getRank();
if (expandedRank < collapsedRank)
- return op.emitOpError("expected the type ")
- << expandedType
- << " to have higher rank than the type = " << collapsedType;
- if (expandedRank == 0)
- return op.emitOpError("expected non-zero memref ranks");
-
- if (collapsedRank == 0) {
- // If collapsed rank is 0, then expanded type must be static shaped and of
- // sizes 1.
- if (llvm::any_of(expandedType.getShape(),
- [](int64_t dim) -> bool { return dim != 1; }))
- return op.emitOpError("invalid to reshape tensor/memref with non-unit "
- "extent dimensions to zero-rank tensor/memref");
- return success();
- }
+ return op.emitOpError("expected the expanded type, ")
+ << expandedType << " to have a higher (or same) rank "
+ << "than the collapsed type, " << collapsedType << '.';
+
if (collapsedRank != op.getReassociation().size())
- return op.emitOpError("expected rank of the collapsed type(")
- << collapsedRank << ") to be the number of reassociation maps("
- << op.getReassociation().size() << ")";
+ return op.emitOpError("expected collapsed rank (")
+ << collapsedRank << ") to equal the number of reassociation maps ("
+ << op.getReassociation().size() << ").";
+
auto maps = op.getReassociationMaps();
for (auto it : llvm::enumerate(maps))
if (it.value().getNumDims() != expandedRank)
return op.emitOpError("expected reassociation map #")
- << it.index() << " of same rank as expanded memref("
- << expandedRank << "), but got " << it.value().getNumDims();
+ << it.index() << " to have size equal to the expanded rank ("
+ << expandedRank << "), but it is " << it.value().getNumDims()
+ << '.';
+
int invalidIdx = 0;
if (!isReassociationValid(maps, &invalidIdx))
return op.emitOpError("expected reassociation map #")
- << invalidIdx << " to be valid and contiguous";
- return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+ << invalidIdx << " to be valid and contiguous.";
+
+ return reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) { return op->emitOpError(msg); },
+ collapsedType.getShape(), expandedType.getShape(),
+ op.getReassociationIndices(), isExpansion);
+
}
/// Verify that shapes of the reshaped types using following rules
@@ -151,16 +154,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
- ShapedType expandedType,
- bool isExpandingReshape) {
- return reshapeLikeShapesAreCompatible(
- [&](const Twine &msg) { return op->emitOpError(msg); },
- collapsedType.getShape(), expandedType.getShape(),
- op.getReassociationIndices(), isExpandingReshape);
-}
-
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79022d675bd6..dc2ad9f54d0b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,15 +2447,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 591cbcde80084..ef8545016d3dc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,15 +1860,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 8c0b765c1b123..7a70e4737fb80 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -19,6 +19,16 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
// -----
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+ %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+ return %1 : memref<1x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 53841fa836da7..6914d72c1ae30 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -448,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
// -----
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.collapse_shape %arg0 [[0], [0]] :
+ memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+ memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.collapse_shape %arg0 [[1], [0]] :
+ memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
func.func @collapse_shape_reshaping_non_contiguous(
%arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
// expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9b31d858d63c5..c227018fd6f08 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -10,6 +10,15 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
// -----
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+ %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
@@ -19,6 +28,15 @@ func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf
// -----
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+ %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
More information about the Mlir-commits
mailing list