[Mlir-commits] [mlir] 36aac53 - [mlir][linalg] Extend drop unit dim pattern to all cases of reduction
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 17 10:12:17 PDT 2021
Author: thomasraoux
Date: 2021-09-17T10:09:57-07:00
New Revision: 36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5
URL: https://github.com/llvm/llvm-project/commit/36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5
DIFF: https://github.com/llvm/llvm-project/commit/36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5.diff
LOG: [mlir][linalg] Extend drop unit dim pattern to all cases of reduction
Even when all loops are parallel, reading the output value is still allowed, so
we don't have to handle reduction loops differently.
Differential Revision: https://reviews.llvm.org/D109851
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e23a58e50cf18..8315de4c72e7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -187,40 +187,13 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
return failure();
SmallVector<int64_t> dims = genericOp.getStaticShape();
- // Find all the reduction iterators. Those need some special consideration
- // (see below).
- auto getLoopDimsOfType =
- [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
- SmallVector<AffineExpr> dimExprs;
- getDimsOfType(genericOp, iteratorTypeName, dimExprs);
- return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
- }));
- };
- auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
-
DenseSet<unsigned> unitDims;
SmallVector<unsigned, 4> unitDimsReductionLoops;
ArrayAttr iteratorTypes = genericOp.iterator_types();
for (auto expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
- if (dims[dimExpr.getPosition()] == 1) {
- if (isParallelIterator(iteratorTypes[expr.index()]))
- unitDims.insert(expr.index());
- else if (isReductionIterator(iteratorTypes[expr.index()]))
- unitDimsReductionLoops.push_back(expr.index());
- }
- }
-
- // Reduction loops can be dropped if there is at least one other reduction
- // loop that is not dropped. This accounts for the initial value read in the
- // reduction loop.
- if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
- if (unitDimsReductionLoops.size() == reductionDims.size())
- unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
- else
- unitDims.insert(unitDimsReductionLoops.begin(),
- unitDimsReductionLoops.end());
+ if (dims[dimExpr.getPosition()] == 1)
+ unitDims.insert(expr.index());
}
if (unitDims.empty())
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 0271638999a7a..60ad72300a185 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -361,7 +361,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
// -----
-func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
%cst = constant 1.000000e+00 : f32
%c3 = constant 3 : index
%1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
@@ -378,17 +378,16 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
} -> tensor<1x1xf32>
return %3 : tensor<1x1xf32>
}
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @unit_dim_for_reduction_keep_one
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @unit_dim_for_both_reduction
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
More information about the Mlir-commits
mailing list