[Mlir-commits] [mlir] [MLIR][Tensor, MemRef] Fold expand_shape and collapse_shape if identity (PR #80658)
James Newling
llvmlistbot at llvm.org
Thu Feb 29 11:52:04 PST 2024
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/80658
>From 847a899ab85d56b589a078a74b267cf2158cb1cf Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 01:48:47 -0800
Subject: [PATCH 1/5] Let identity (equal-rank) expand_shape/collapse_shape pass verification
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 --------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ------------
mlir/test/Dialect/MemRef/invalid.mlir | 14 --------------
mlir/test/Dialect/Tensor/invalid.mlir | 14 --------------
4 files changed, 48 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..79c91a5f8c6905 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2224,10 +2224,6 @@ LogicalResult ExpandShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
- if (srcType.getRank() >= resultType.getRank())
- return emitOpError("expected rank expansion, but found source rank ")
- << srcType.getRank() << " >= result rank " << resultType.getRank();
-
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
resultType.getShape(),
@@ -2378,10 +2374,6 @@ LogicalResult CollapseShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
- if (srcType.getRank() <= resultType.getRank())
- return emitOpError("expected rank reduction, but found source rank ")
- << srcType.getRank() << " <= result rank " << resultType.getRank();
-
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..4dec5d7280ce55 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1656,22 +1656,10 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
- if (srcType.getRank() >= resultType.getRank())
- return emitOpError("expected rank expansion, but found source rank ")
- << srcType.getRank() << " >= result rank " << resultType.getRank();
-
return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
LogicalResult CollapseShapeOp::verify() {
- auto srcType = getSrcType();
- auto resultType = getResultType();
- if (srcType.getRank() <= resultType.getRank())
- return emitOpError("expected rank reduction, but found source rank ")
- << srcType.getRank() << " <= result rank " << resultType.getRank();
-
return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 8f5ba5ea8fc78d..b7fab09f6f97df 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -415,20 +415,6 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
// -----
-func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
- // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
- %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
- // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
- %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
// expected-error @+1 {{op reassociation index 2 is out of bounds}}
%0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 4c534fe936e3de..79ca0de68a1e9b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -343,20 +343,6 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
// -----
-func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
- // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
- %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
- // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
- %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
func.func @rank(%0: f32) {
// expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
"tensor.rank"(%0): (f32)->index
>From 7906997ac7195ef6498b82aa817e49bcfdde2f85 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 02:49:15 -0800
Subject: [PATCH 2/5] Fold identity expand_shape/collapse_shape and add tests
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 2 --
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++++
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 ++++++++++++++++++-
5 files changed, 47 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f272..3a672a1cc60601 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -111,8 +111,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
<< " to have higher rank than the type = " << collapsedType;
if (expandedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
- if (expandedRank == collapsedRank)
- return op.emitOpError("expected to collapse or expand dims");
if (collapsedRank == 0) {
// If collapsed rank is 0, then expanded type must be static shaped and of
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 79c91a5f8c6905..0fc4148057cb3e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,11 +2448,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4dec5d7280ce55..11d7f95f5943e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,11 +1860,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ if (getSrcType() == getType())
+ return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..cd057d0478aaae 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
+ %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
+ return %0 : memref<5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
+ %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+ return %0 : memref<5x4xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..1fee7ee263b9da 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
+ %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
+ return %0 : tensor<5x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
@@ -2069,7 +2088,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
// Chain: NC -> NCnc -> NCnc -> NC
// CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
// CHECK: return %[[T]] : tensor<128x128xf32>
func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
%tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
>From 75ddea2c695000b2ec5d1ca7b54da5e83e93d998 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 05:32:48 -0800
Subject: [PATCH 3/5] Move identity fold into foldReshapeOp; add tests, correct
 docs, and remove rank-0 special-case handling
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 2 +-
.../mlir/Dialect/Tensor/IR/TensorOps.td | 39 +++++-------
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 63 +++++++++----------
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 --
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 --
mlir/test/Dialect/MemRef/canonicalize.mlir | 10 +++
mlir/test/Dialect/MemRef/invalid.mlir | 28 +++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 18 ++++++
8 files changed, 100 insertions(+), 68 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..39e66cd9e6e5ab 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
let summary = "non-blocking DMA operation that starts a transfer";
let description = [{
Syntax:
-
+
```
operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 1c61ece2676a90..670202fe4372e6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let summary = "operation to produce a tensor with a higher rank";
let description = [{
- The `tensor.expand_shape` op produces a new tensor with a higher
- rank whose sizes are a reassociation of the original `src`.
+ The `tensor.expand_shape` op produces a tensor of higher (or equal)
+ rank than the operand `src` whose dimension sizes are a reassociation of
+ `src`.
- A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of DenseI64ArrayAttr attribute.
-
- The verification rule is that the reassociation maps are applied to the
- result tensor with the higher rank to obtain the operand tensor with the
- smaller rank.
+ A reassociation is defined as a continuous grouping of dimensions. It is
+ represented with an array of DenseI64ArrayAttr attribute. Entries in the
+ array are referred to as reassociation maps.
- The operand tensor type of a reshape can be zero-ranked if the result
- tensor type is statically shaped with all dimensions being unit extent. In
- such cases the reassociation map is empty.
+ The reassociation maps are applied to the result shape to obtain the operand
+ shape.
- Examples:
+ Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
let description = [{
- The `tensor.collapse_shape` op produces a new tensor with a smaller
- rank whose sizes are a reassociation of the original `src`.
+ The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+ rank whose dimension sizes are a reassociation of the original `src` dimensions.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of DenseI64ArrayAttr attribute.
+ represented by an array of DenseI64ArrayAttr attribute. The reassociation
+ maps are applied to the operand shape to obtain the result shape.
- The verification rule is that the reassociation maps are applied to the
- operand tensor with the higher rank to obtain the result tensor with the
- smaller rank.
- The result tensor type of a reshape can be zero-ranked if the operand
- tensor type is statically shaped with all dimensions being unit extent. In
- such case the reassociation map is empty.
-
- Examples:
+ Example:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
and optionally transposes the tiled source tensor dimensions.
`inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
- being tiled, where `0 < k <= n`. The order of the dimensions matters:
+ being tiled, where `0 < k <= n`. The order of the dimensions matters:
- The tiled dimensions (of size `inner_tiles`) are added to the end of the result
tensor in the order in which they appear in `inner_dims_pos`.
- `inner_dims_pos[i]` specifies the source tensor dimension tiled by
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a672a1cc60601..7192b5846cb3d9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
- // Fold producer-consumer reshape ops that where the operand type of the
+
+ if (reshapeOp.getSrcType() == reshapeOp.getType())
+ return reshapeOp.getSrc();
+
+ // Fold producer-consumer reshape ops where the operand type of the
// producer is same as the return type of the consumer.
auto reshapeSrcOp =
reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.getSrc();
+
// Reshape of a constant can be replaced with a new constant.
- if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
- }
+
return nullptr;
}
@@ -103,39 +108,37 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
T collapsedType, bool isExpansion) {
+
unsigned expandedRank = expandedType.getRank();
unsigned collapsedRank = collapsedType.getRank();
if (expandedRank < collapsedRank)
- return op.emitOpError("expected the type ")
- << expandedType
- << " to have higher rank than the type = " << collapsedType;
- if (expandedRank == 0)
- return op.emitOpError("expected non-zero memref ranks");
-
- if (collapsedRank == 0) {
- // If collapsed rank is 0, then expanded type must be static shaped and of
- // sizes 1.
- if (llvm::any_of(expandedType.getShape(),
- [](int64_t dim) -> bool { return dim != 1; }))
- return op.emitOpError("invalid to reshape tensor/memref with non-unit "
- "extent dimensions to zero-rank tensor/memref");
- return success();
- }
+ return op.emitOpError("expected the expanded type, ")
+ << expandedType << " to have a higher (or same) rank "
+ << "than the collapsed type, " << collapsedType << '.';
+
if (collapsedRank != op.getReassociation().size())
- return op.emitOpError("expected rank of the collapsed type(")
- << collapsedRank << ") to be the number of reassociation maps("
- << op.getReassociation().size() << ")";
+ return op.emitOpError("expected collapsed rank (")
+ << collapsedRank << ") to equal the number of reassociation maps ("
+ << op.getReassociation().size() << ").";
+
auto maps = op.getReassociationMaps();
for (auto it : llvm::enumerate(maps))
if (it.value().getNumDims() != expandedRank)
return op.emitOpError("expected reassociation map #")
- << it.index() << " of same rank as expanded memref("
- << expandedRank << "), but got " << it.value().getNumDims();
+ << it.index() << " to have size equal to the expanded rank ("
+ << expandedRank << "), but it is " << it.value().getNumDims()
+ << '.';
+
int invalidIdx = 0;
if (!isReassociationValid(maps, &invalidIdx))
return op.emitOpError("expected reassociation map #")
- << invalidIdx << " to be valid and contiguous";
- return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+ << invalidIdx << " to be valid and contiguous.";
+
+ return reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) { return op->emitOpError(msg); },
+ collapsedType.getShape(), expandedType.getShape(),
+ op.getReassociationIndices(), isExpansion);
+
}
/// Verify that shapes of the reshaped types using following rules
@@ -151,16 +154,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
- ShapedType expandedType,
- bool isExpandingReshape) {
- return reshapeLikeShapesAreCompatible(
- [&](const Twine &msg) { return op->emitOpError(msg); },
- collapsedType.getShape(), expandedType.getShape(),
- op.getReassociationIndices(), isExpandingReshape);
-}
-
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0fc4148057cb3e..79c91a5f8c6905 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,15 +2448,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 11d7f95f5943e2..4dec5d7280ce55 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1860,15 +1860,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
- if (getSrcType() == getType())
- return getSrc();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index cd057d0478aaae..b1e92e54d561da 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -19,6 +19,16 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
// -----
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+ %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+ return %1 : memref<1x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_size_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index b7fab09f6f97df..64c91cf40d0a1e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -448,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
// -----
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.collapse_shape %arg0 [[0], [0]] :
+ memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+ memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{reassociation indices must be contiguous}}
+ %0 = memref.collapse_shape %arg0 [[1], [0]] :
+ memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
func.func @collapse_shape_reshaping_non_contiguous(
%arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
// expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1fee7ee263b9da..6086a740aef671 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -10,6 +10,15 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
// -----
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+ %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
@@ -19,6 +28,15 @@ func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf
// -----
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+ %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
>From 3684e452fe06bfd34b6694aa8c333fcfc582e83f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Feb 2024 06:04:44 -0800
Subject: [PATCH 4/5] clang-format
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 7192b5846cb3d9..ae9824f728da4d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -138,7 +138,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
[&](const Twine &msg) { return op->emitOpError(msg); },
collapsedType.getShape(), expandedType.getShape(),
op.getReassociationIndices(), isExpansion);
-
}
/// Verify that shapes of the reshaped types using following rules
>From e9088b01c5043448093d10e040842abb5a0fe184 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Sat, 24 Feb 2024 13:47:53 -0800
Subject: [PATCH 5/5] Improve error message when the rank change has the wrong direction
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 16 ++++++++++++++++
mlir/test/Dialect/MemRef/invalid.mlir | 4 ++--
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 79c91a5f8c6905..94e0ed319cae83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2224,6 +2224,14 @@ LogicalResult ExpandShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
+ if (srcType.getRank() > resultType.getRank()) {
+ auto r0 = srcType.getRank();
+ auto r1 = resultType.getRank();
+ return emitOpError("has source rank ")
+ << r0 << " and result rank " << r1 << ". This is not an expansion ("
+ << r0 << " > " << r1 << ").";
+ }
+
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
resultType.getShape(),
@@ -2374,6 +2382,14 @@ LogicalResult CollapseShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
+ if (srcType.getRank() < resultType.getRank()) {
+ auto r0 = srcType.getRank();
+ auto r1 = resultType.getRank();
+ return emitOpError("has source rank ")
+ << r0 << " and result rank " << r1 << ". This is not a collapse ("
+ << r0 << " < " << r1 << ").";
+ }
+
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 64c91cf40d0a1e..1aef417549d9a1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -451,7 +451,7 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
// An (invalid) attempt at using collapse_shape to increase the rank might look
// like this. Verify that a sensible error is emitted in this case.
func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
- // expected-error @+1 {{reassociation indices must be contiguous}}
+ // expected-error @+1 {{'memref.collapse_shape' op has source rank 1 and result rank 2. This is not a collapse (1 < 2)}}
%0 = memref.collapse_shape %arg0 [[0], [0]] :
memref<?xf32> into memref<?x?xf32>
}
@@ -461,7 +461,7 @@ func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>)
// An (invalid) attempt at using expand_shape to reduce the rank might look
// like this. Verify that a sensible error is emitted in this case.
func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
- // expected-error @+1 {{reassociation indices must be contiguous}}
+ // expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}}
%0 = memref.expand_shape %arg0 [[0], [1], [1]] :
memref<2x3x1xf32> into memref<2x3xf32>
}
More information about the Mlir-commits
mailing list