[Mlir-commits] [mlir] [mlir][SCF] Allow canonicalization of zero-trip count `scf.forall` with empty mapping. (PR #105793)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 23 10:21:53 PDT 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/105793

>From 66cee0ef36a1bb91b739f5a91aecbfb2389c45db Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 23 Aug 2024 00:15:06 -0700
Subject: [PATCH] [mlir][SCF] Allow canonicalization of zero-trip count
 `scf.forall` with empty mapping.

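Currently, the folding of single-iteration `scf.forall` loops does not
kick in when the loop carries an empty mapping (`mapping = []`). Enable
it for empty mappings as well.

The early exit taken when no loop dimension has 0 or 1 iterations is also
moved after the check that inlines the body when every dimension folds
away. A zero-dimensional loop (`scf.forall () in ()`) executes exactly
once, and it is now inlined instead of being left in place. The Tensor
test `dont_fold_parallel_insert_slice` relied on such a loop surviving
canonicalization, so it is removed.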

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 13 ++++++-----
 mlir/test/Dialect/SCF/canonicalize.mlir    | 27 ++++++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 21 -----------------
 3 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e92d9503372cdf..bfa7db84bd9af7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   LogicalResult matchAndRewrite(ForallOp op,
                                 PatternRewriter &rewriter) const override {
     // Do not fold dimensions if they are mapped to processing units.
-    if (op.getMapping().has_value())
+    if (op.getMapping().has_value() && !op.getMapping()->empty())
       return failure();
     Location loc = op.getLoc();
 
@@ -1729,11 +1729,6 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       newMixedUpperBounds.push_back(ub);
       newMixedSteps.push_back(step);
     }
-    // Exit if none of the loop dimensions perform a single iteration.
-    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
-      return rewriter.notifyMatchFailure(
-          op, "no dimensions have 0 or 1 iterations");
-    }
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
@@ -1741,6 +1736,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       return success();
     }
 
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
     // Replace the loop by a lower-dimensional loop.
     ForallOp newOp;
     newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 268946803de7a5..c68369a8e4fce7 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
 
 // -----
 
+func.func @inline_empty_loop_with_empty_mapping(
+    %in: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<16xf32>
+  %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
+    %slice = tensor.extract_slice %out_[0] [16] [1]
+      : tensor<16xf32> to tensor<16xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%slice : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %2 = arith.addf %b0, %b0 : f32
+        linalg.yield %2 : f32
+    } -> tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
+        : tensor<16xf32> into tensor<16xf32>
+    }
+  }{ mapping = [] }
+  return %1 : tensor<16xf32>
+}
+// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
+//   CHECK-NOT:   scf.forall
+
+// -----
+
 func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c8 = arith.constant 8 : index
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23c..458ff51be7462e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2076,27 +2076,6 @@ func.func @canonicalize_parallel_insert_slice_indices(
 
 // -----
 
-// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
-//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
-//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
-func.func @dont_fold_parallel_insert_slice(
-    %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  //      CHECK: scf.forall () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) {
-  // CHECK-NEXT:   scf.forall.in_parallel {
-  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
-  %2 = scf.forall () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) {
-    scf.forall.in_parallel {
-      tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
-    }
-  }
-  return %2 : tensor<1x5xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
 func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {



More information about the Mlir-commits mailing list