[Mlir-commits] [mlir] [mlir][SCF] Allow canonicalization of zero-trip count `scf.forall` with empty mapping. (PR #105793)

llvmlistbot at llvm.org
Fri Aug 23 00:19:08 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/105793

The current folding of zero- and single-trip-count `scf.forall` loops does not kick in when the loop carries an empty `mapping` attribute: the pattern bails out as soon as `getMapping()` has a value, even if the mapping array is empty. Treat an empty mapping the same as no mapping, and reorder the checks so a rank-0 loop (all dimensions folded away) is inlined instead of being rejected.
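
As a rough before/after sketch (the function name @fold_rank_zero_forall, the %init value, and the empty terminator are made up for illustration; the test added in this patch is the authoritative example), this is the kind of IR the pattern can now fold:

  // Before: a rank-0 `scf.forall` whose `mapping` attribute is present but
  // empty. The pattern used to bail out because `getMapping().has_value()`
  // is true even for `mapping = []`.
  func.func @fold_rank_zero_forall(%init: tensor<4xf32>) -> tensor<4xf32> {
    %0 = scf.forall () in () shared_outs(%arg = %init) -> (tensor<4xf32>) {
      scf.forall.in_parallel {
      }
    } {mapping = []}
    return %0 : tensor<4xf32>
  }

  // After -canonicalize: the body is inlined, the loop is erased, and the
  // result is (roughly) replaced by the shared_outs operand.
  func.func @fold_rank_zero_forall(%init: tensor<4xf32>) -> tensor<4xf32> {
    return %init : tensor<4xf32>
  }

The order of the two checks in ForallOpSingleOrZeroIterationDimsFolder matters here: for a rank-0 loop both `newMixedLowerBounds.empty()` and `newMixedLowerBounds.size() == op.getRank()` hold, so the "no dimensions have 0 or 1 iterations" bail-out has to come after the inlining path, which is why the patch moves it below.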

From 1094b9bed552d0d603f79788f99a4e89b2abc850 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 23 Aug 2024 00:15:06 -0700
Subject: [PATCH] [mlir][SCF] Allow canonicalization of zero-trip count
 `scf.forall` with empty mapping.

The current folding of zero- and single-trip-count loops does not kick
in when the mapping attribute is present but empty. Enable the fold for
an empty mapping as well.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 13 ++++++------
 mlir/test/Dialect/SCF/canonicalize.mlir | 27 +++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e92d9503372cdf..bfa7db84bd9af7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   LogicalResult matchAndRewrite(ForallOp op,
                                 PatternRewriter &rewriter) const override {
     // Do not fold dimensions if they are mapped to processing units.
-    if (op.getMapping().has_value())
+    if (op.getMapping().has_value() && !op.getMapping()->empty())
       return failure();
     Location loc = op.getLoc();
 
@@ -1729,11 +1729,6 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       newMixedUpperBounds.push_back(ub);
       newMixedSteps.push_back(step);
     }
-    // Exit if none of the loop dimensions perform a single iteration.
-    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
-      return rewriter.notifyMatchFailure(
-          op, "no dimensions have 0 or 1 iterations");
-    }
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
@@ -1741,6 +1736,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       return success();
     }
 
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
     // Replace the loop by a lower-dimensional loop.
     ForallOp newOp;
     newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 268946803de7a5..ff7fafac42cb4a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
 
 // -----
 
+func.func @inline_empty_loop_with_empty_mapping(
+    %in: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<16xf32>
+  %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
+    %slice = tensor.extract_slice %out_[0] [16] [1]
+      : tensor<16xf32> to tensor<16xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%slice : tensor<16xf32>) outs(%slice : tensor<16xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %2 = arith.addf %b0, %b0 : f32
+        linalg.yield %2 : f32
+    } -> tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
+        : tensor<16xf32> into tensor<16xf32>
+    }
+  }{ mapping = [] }
+  return %1 : tensor<16xf32>
+}
+// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
+//   CHECK-NOT:   scf.forall
+
+// -----
+
 func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c8 = arith.constant 8 : index
   %c0 = arith.constant 0 : index


