[Mlir-commits] [mlir] f2d91a7 - [mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 23 04:56:02 PST 2022
Author: Matthias Springer
Date: 2022-11-23T13:52:00+01:00
New Revision: f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4
URL: https://github.com/llvm/llvm-project/commit/f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4
DIFF: https://github.com/llvm/llvm-project/commit/f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4.diff
LOG: [mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp
Do not generate CollapseShapeOps/ExpandShapeOps that have the same source and result shape. Generate casts instead. Such reshapes became invalid with D138498.
Differential Revision: https://reviews.llvm.org/D138557
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index dba055d9fd992..760a5aab14489 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -225,7 +225,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
//
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -250,8 +250,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
SmallVector<ReassociationIndices, 4> higherRankReassociation,
lowerRankReassociation;
- bool isResultCollapsed = srcRank > resultRank;
- if (isResultCollapsed) {
+ if (srcRank > resultRank) {
higherRankReassociation = expandOp.getReassociationIndices();
lowerRankReassociation = collapseOp.getReassociationIndices();
} else {
@@ -274,12 +273,20 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
composedReassociation.push_back(composedIndices);
}
- if (isResultCollapsed)
+ if (srcRank > resultRank) {
rewriter.replaceOpWithNewOp<CollapseOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
- else
+ } else if (srcRank < resultRank) {
rewriter.replaceOpWithNewOp<ExpandOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+ } else {
+ // Collapses/expansions that do not change the rank are not allowed. Use
+ // a cast instead.
+ assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+ "expected same shape");
+ rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+ expandOp.getSrc());
+ }
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2bbb57e6e0d28..503c8aed2709d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,7 +2447,7 @@ struct CollapseShapeOpMemRefCastFolder
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
CollapseShapeOpMemRefCastFolder>(context);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 36e3aadbc5982..23af46c6d7912 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1586,7 +1586,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d34d0d95d0392..d9710b75f5b3e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -859,3 +859,19 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
memref.store %v, %0[%i2] : memref<4xf32>
return %src : memref<2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
+// CHECK: return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+ -> (memref<?xf32, 3>)
+{
+ %0 = memref.expand_shape %m [[0, 1]]
+ : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
+ %1 = memref.collapse_shape %0 [[0, 1]]
+ : memref<1x?xf32, 3> into memref<?xf32, 3>
+ return %1 : memref<?xf32, 3>
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c9e662f969d74..92e329d916a20 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1666,3 +1666,15 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
%1 = tensor.dim %0, %c1 : tensor<?x?xf32>
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+ %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
More information about the Mlir-commits mailing list