[Mlir-commits] [mlir] 10d7924 - Fix FoldReshapeOpWithUnitExtent generating illegal reshape

Ahmed Taei llvmlistbot at llvm.org
Wed Apr 21 11:31:50 PDT 2021


Author: Ahmed Taei
Date: 2021-04-21T11:30:45-07:00
New Revision: 10d7924581f8f29c558d089c2546321de26f8849

URL: https://github.com/llvm/llvm-project/commit/10d7924581f8f29c558d089c2546321de26f8849
DIFF: https://github.com/llvm/llvm-project/commit/10d7924581f8f29c558d089c2546321de26f8849.diff

LOG: Fix FoldReshapeOpWithUnitExtent generating illegal reshape

This prevents a fusion that spans all dims and would generate an
illegal (d0, d1, ...) -> () reshape.

Differential Revision: https://reviews.llvm.org/D100805

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 15540596f75c6..5d8a664ef9646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -518,7 +518,16 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
       } else {
         return failure();
       }
+
       foldedDim++;
+      // Once all expanded dims are consumed, no unit dims may remain in the
+      // folded shape; otherwise those dims are left unmapped, which leads to
+      // an illegal reshape.
+      if (expandedDim == expandedShape.size()) {
+        if (foldedDim < foldedShape.size() && foldedShape[foldedDim] == 1) {
+          return failure();
+        }
+      }
     }
     if (expandedDim != expandedShape.size())
       return failure();

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 8a36f6dce78ba..e9dd74faad64b 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -647,3 +647,21 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+  return %1 : tensor<12x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//      CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
+//      CHECK:    %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0:.+]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+//      CHECK:    %[[RES:.+]] = linalg.tensor_reshape %[[RARG0:.+]] [#[[MAP3]], #[[MAP4]]]
+//      CHECK:    return %[[RES:.+]] : tensor<12x1xf32>


        


More information about the Mlir-commits mailing list