[Mlir-commits] [mlir] f2d91a7 - [mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp

Matthias Springer llvmlistbot at llvm.org
Wed Nov 23 04:56:02 PST 2022


Author: Matthias Springer
Date: 2022-11-23T13:52:00+01:00
New Revision: f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4

URL: https://github.com/llvm/llvm-project/commit/f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4
DIFF: https://github.com/llvm/llvm-project/commit/f2d91a7ae1172d4cf8f06a0204a2dc64a0d68ea4.diff

LOG: [mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp

Do not generate CollapseShapeOps/ExpandShapeOps whose source and result types have the same shape (i.e., reshapes that do not change the rank); generate casts instead. Such rank-preserving reshapes became invalid with D138498.

Differential Revision: https://reviews.llvm.org/D138557

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index dba055d9fd992..760a5aab14489 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -225,7 +225,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -250,8 +250,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
     SmallVector<ReassociationIndices, 4> higherRankReassociation,
         lowerRankReassociation;
 
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
+    if (srcRank > resultRank) {
       higherRankReassociation = expandOp.getReassociationIndices();
       lowerRankReassociation = collapseOp.getReassociationIndices();
     } else {
@@ -274,12 +273,20 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       }
       composedReassociation.push_back(composedIndices);
     }
-    if (isResultCollapsed)
+    if (srcRank > resultRank) {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
-    else
+    } else if (srcRank < resultRank) {
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+    } else {
+      // Collapses/expansions that do not change the rank are not allowed. Use
+      // a cast instead.
+      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+             "expected same shape");
+      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+                                            expandOp.getSrc());
+    }
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2bbb57e6e0d28..503c8aed2709d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,7 +2447,7 @@ struct CollapseShapeOpMemRefCastFolder
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 36e3aadbc5982..23af46c6d7912 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1586,7 +1586,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
            FoldReshapeWithConstant<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
           context);

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d34d0d95d0392..d9710b75f5b3e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -859,3 +859,19 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
   memref.store %v, %0[%i2] : memref<4xf32>
   return %src : memref<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
+//       CHECK:   return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+    -> (memref<?xf32, 3>)
+{
+  %0 = memref.expand_shape %m [[0, 1]]
+      : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
+  %1 = memref.collapse_shape %0 [[0, 1]]
+      : memref<1x?xf32, 3> into memref<?xf32, 3>
+  return %1 : memref<?xf32, 3>
+}

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c9e662f969d74..92e329d916a20 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1666,3 +1666,15 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+  %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list