[Mlir-commits] [mlir] [mlir][SCF] Allow canonicalization of zero-trip count `scf.forall` with empty mapping. (PR #105793)
llvmlistbot at llvm.org
Fri Aug 23 10:21:53 PDT 2024
MaheshRavishankar (https://github.com/MaheshRavishankar) updated https://github.com/llvm/llvm-project/pull/105793
From 66cee0ef36a1bb91b739f5a91aecbfb2389c45db Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 23 Aug 2024 00:15:06 -0700
Subject: [PATCH] [mlir][SCF] Allow canonicalization of zero-trip count
`scf.forall` with empty mapping.
The folding of zero-/single-iteration-count loops currently does not kick in
when the loop carries an empty `mapping` attribute. Treat an empty mapping the
same as an absent one. Also check the "all dimensions fold away" case (which
inlines the loop body) before the "no dimensions fold away" early exit: for a
rank-0 loop both sizes are zero, so the early exit used to fire first and
block the fold.
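
For example, with this change a rank-0 loop such as the following (a trimmed
variant of the new test below; "some.op" is a placeholder for the payload and
%dest is some existing tensor<16xf32> value) now folds away:

  %1 = scf.forall () in () shared_outs(%out = %dest) -> (tensor<16xf32>) {
    %v = "some.op"(%out) : (tensor<16xf32>) -> tensor<16xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %v into %out[0] [16] [1]
          : tensor<16xf32> into tensor<16xf32>
    }
  } {mapping = []}

The body is inlined in place of the loop, with the
tensor.parallel_insert_slice in the terminator becoming a plain
tensor.insert_slice, roughly:

  %v = "some.op"(%dest) : (tensor<16xf32>) -> tensor<16xf32>
  %1 = tensor.insert_slice %v into %dest[0] [16] [1]
      : tensor<16xf32> into tensor<16xf32>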
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 13 ++++++-----
mlir/test/Dialect/SCF/canonicalize.mlir | 27 ++++++++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 -----------------
3 files changed, 34 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e92d9503372cdf..bfa7db84bd9af7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
LogicalResult matchAndRewrite(ForallOp op,
PatternRewriter &rewriter) const override {
// Do not fold dimensions if they are mapped to processing units.
- if (op.getMapping().has_value())
+ if (op.getMapping().has_value() && !op.getMapping()->empty())
return failure();
Location loc = op.getLoc();
@@ -1729,11 +1729,6 @@ struct ForallOpSingleOrZeroIterationDimsFolder
newMixedUpperBounds.push_back(ub);
newMixedSteps.push_back(step);
}
- // Exit if none of the loop dimensions perform a single iteration.
- if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
- return rewriter.notifyMatchFailure(
- op, "no dimensions have 0 or 1 iterations");
- }
// All of the loop dimensions perform a single iteration. Inline loop body.
if (newMixedLowerBounds.empty()) {
@@ -1741,6 +1736,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
return success();
}
+ // Exit if none of the loop dimensions perform a single iteration.
+ if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+ return rewriter.notifyMatchFailure(
+ op, "no dimensions have 0 or 1 iterations");
+ }
+
// Replace the loop by a lower-dimensional loop.
ForallOp newOp;
newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 268946803de7a5..c68369a8e4fce7 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
// -----
+func.func @inline_empty_loop_with_empty_mapping(
+ %in: tensor<16xf32>) -> tensor<16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<16xf32>
+ %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
+ %slice = tensor.extract_slice %out_[0] [16] [1]
+ : tensor<16xf32> to tensor<16xf32>
+ %generic = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%slice : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %2 = arith.addf %b0, %b0 : f32
+ linalg.yield %2 : f32
+ } -> tensor<16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
+ : tensor<16xf32> into tensor<16xf32>
+ }
+ }{ mapping = [] }
+ return %1 : tensor<16xf32>
+}
+// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
+// CHECK-NOT: scf.forall
+
+// -----
+
func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23c..458ff51be7462e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2076,27 +2076,6 @@ func.func @canonicalize_parallel_insert_slice_indices(
// -----
-// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
-// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
-func.func @dont_fold_parallel_insert_slice(
- %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
-{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- // CHECK: scf.forall () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) {
- // CHECK-NEXT: scf.forall.in_parallel {
- // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
- %2 = scf.forall () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) {
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
- }
- }
- return %2 : tensor<1x5xf32>
-}
-
-// -----
-
// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {