[Mlir-commits] [mlir] 36aac53 - [mlir][linalg] Extend drop unit dim pattern to all cases of reduction

llvmlistbot at llvm.org
Fri Sep 17 10:12:17 PDT 2021


Author: thomasraoux
Date: 2021-09-17T10:09:57-07:00
New Revision: 36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5

URL: https://github.com/llvm/llvm-project/commit/36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5
DIFF: https://github.com/llvm/llvm-project/commit/36aac53b36983c5ae8b7dcb0519c34e8c41dc4e5.diff

LOG: [mlir][linalg] Extend drop unit dim pattern to all cases of reduction

Even when all loops are parallel, reading the output value is still allowed, so we
don't have to handle reduction loops differently.

Differential Revision: https://reviews.llvm.org/D109851
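
As a rough illustration (a hypothetical example in 2021-era syntax, not taken from
this patch), consider a generic op whose only reduction dimension has unit extent.
With this change FoldUnitDimLoops can drop that loop outright, leaving a purely
parallel op, since the body may still read the output value even when no reduction
iterator remains:

    #map0 = affine_map<(d0, d1) -> (d0, d1)>
    #map1 = affine_map<(d0, d1) -> (d0)>
    // Hypothetical input: the d1 reduction loop has unit extent.
    func @sum_unit_reduction(%arg0: tensor<4x1xf32>, %init: tensor<4xf32>) -> tensor<4xf32> {
      %0 = linalg.generic {indexing_maps = [#map0, #map1],
                           iterator_types = ["parallel", "reduction"]}
          ins(%arg0 : tensor<4x1xf32>) outs(%init : tensor<4xf32>) {
      ^bb0(%in: f32, %out: f32):
        %add = addf %in, %out : f32
        linalg.yield %add : f32
      } -> tensor<4xf32>
      return %0 : tensor<4xf32>
    }
    // After FoldUnitDimLoops (sketch): iterator_types becomes ["parallel"], the
    // input is collapsed to tensor<4xf32>, and the addf reading %out stays valid.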

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e23a58e50cf18..8315de4c72e7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -187,40 +187,13 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
       return failure();
     SmallVector<int64_t> dims = genericOp.getStaticShape();
 
-    // Find all the reduction iterators. Those need some special consideration
-    // (see below).
-    auto getLoopDimsOfType =
-        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
-      SmallVector<AffineExpr> dimExprs;
-      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
-      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
-        return expr.cast<AffineDimExpr>().getPosition();
-      }));
-    };
-    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
-
     DenseSet<unsigned> unitDims;
     SmallVector<unsigned, 4> unitDimsReductionLoops;
     ArrayAttr iteratorTypes = genericOp.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1) {
-          if (isParallelIterator(iteratorTypes[expr.index()]))
-            unitDims.insert(expr.index());
-          else if (isReductionIterator(iteratorTypes[expr.index()]))
-            unitDimsReductionLoops.push_back(expr.index());
-        }
-    }
-
-    // Reduction loops can be dropped if there is at least one other reduction
-    // loop that is not dropped. This accounts for the initial value read in the
-    // reduction loop.
-    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
-      if (unitDimsReductionLoops.size() == reductionDims.size())
-        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
-      else
-        unitDims.insert(unitDimsReductionLoops.begin(),
-                        unitDimsReductionLoops.end());
+        if (dims[dimExpr.getPosition()] == 1)
+          unitDims.insert(expr.index());
     }
 
     if (unitDims.empty())

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 0271638999a7a..60ad72300a185 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -361,7 +361,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 
 // -----
 
-func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
   %cst = constant 1.000000e+00 : f32
   %c3 = constant 3 : index
   %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
@@ -378,17 +378,16 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
   } -> tensor<1x1xf32>
   return %3 : tensor<1x1xf32>
 }
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
-//      CHECK: func @unit_dim_for_reduction_keep_one
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+//      CHECK: func @unit_dim_for_both_reduction
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
-// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     iterator_types = ["parallel"]
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<1xf32>)
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //      CHECK:   return %[[RESULT_RESHAPE]]


        

