[Mlir-commits] [mlir] 10d7924 - Fix FoldReshapeOpWithUnitExtent generating illegal reshape
Ahmed Taei
llvmlistbot at llvm.org
Wed Apr 21 11:31:50 PDT 2021
Author: Ahmed Taei
Date: 2021-04-21T11:30:45-07:00
New Revision: 10d7924581f8f29c558d089c2546321de26f8849
URL: https://github.com/llvm/llvm-project/commit/10d7924581f8f29c558d089c2546321de26f8849
DIFF: https://github.com/llvm/llvm-project/commit/10d7924581f8f29c558d089c2546321de26f8849.diff
LOG: Fix FoldReshapeOpWithUnitExtent generating illegal reshape
This will prevent fusion that spans all dims and generates a
(d0, d1, ...) -> () reshape, which isn't legal.
Differential Revision: https://reviews.llvm.org/D100805
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 15540596f75c6..5d8a664ef9646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -518,7 +518,16 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
} else {
return failure();
}
+
foldedDim++;
+ // If inner most dims are folded there shouldn't be any leading 1 dims.
+ // otherwise these dims are not mapped and will lead into an illegal
+ // reshape.
+ if (expandedDim == expandedShape.size()) {
+ if (foldedDim < foldedShape.size() && foldedShape[foldedDim] == 1) {
+ return failure();
+ }
+ }
}
if (expandedDim != expandedShape.size())
return failure();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 8a36f6dce78ba..e9dd74faad64b 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -647,3 +647,21 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+ return %1 : tensor<12x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
+// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0:.+]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0:.+]] [#[[MAP3]], #[[MAP4]]]
+// CHECK: return %[[RES:.+]] : tensor<12x1xf32>
More information about the Mlir-commits
mailing list