[Mlir-commits] [mlir] [mlir][scf] Considering affine.apply when fusing scf::ParallelOp (PR #80145)

Hsiangkai Wang llvmlistbot at llvm.org
Thu Feb 1 02:52:48 PST 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/80145

From bb65dfcc064ca439b482b77cc5ebc4dc929b11fe Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 31 Jan 2024 12:59:15 +0000
Subject: [PATCH] [mlir][scf] Considering affine.apply when fusing
 scf::ParallelOp

When checking whether the load indices of the second loop coincide with
the store indices of the first loop, the pass only compares the index
values for identity. However, in some cases the index values are
produced by affine.apply operations, and the pass then treats them as
different even when the affine maps and the affine operands are the
same.
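
For illustration, here is a condensed version of the first new test
below: the store index and the load index are produced by two separate
but identical affine.apply operations on the corresponding induction
variables, yet fusion used to be rejected.

  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }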

We already check in isFusionLegal() that the two loops have the same
iteration space. When comparing affine.apply operations, it is
therefore enough to check that their operands come from the
corresponding induction variables; if they do, the results of the
affine.apply operations are known to be equal.
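
Conversely, when an affine.apply operand is not an induction variable,
the store and load indices are not known to be equal and fusion must
still be rejected. The second new test below
(do_not_fuse_affine_apply_to_non_ind_var) exercises this case;
condensed:

  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetA)
    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetB)
    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
    scf.reduce
  }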
---
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 27 +++++-
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 95 +++++++++++++++++++
 2 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 8f2ab5f5e6dc1..d3dca1427e517 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
@@ -102,8 +103,30 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i])
-        return WalkResult::interrupt();
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
     }
     return WalkResult::advance();
   });
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 110168ba6eca5..9c136bb635658 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -480,3 +480,98 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
 // CHECK:        scf.reduce
 // CHECK:      }
 // CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @fuse_same_indices_by_affine_apply(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc()  : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK:      #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fuse_same_indices_by_affine_apply
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT:  scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:    %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:    memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:    %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:    %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
+// CHECK-NEXT:    %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
+// CHECK-NEXT:    memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    scf.reduce
+// CHECK-NEXT:  }
+// CHECK-NEXT:  memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT:  return
+
+// -----
+
+func.func @do_not_fuse_affine_apply_to_non_ind_var(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %OffsetA: index, %OffsetB: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc()  : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetA)
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetB)
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT:    scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:      %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
+// CHECK-NEXT:      memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:      scf.reduce
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:      %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
+// CHECK-NEXT:      %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
+// CHECK-NEXT:      %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
+// CHECK-NEXT:      memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      scf.reduce
+// CHECK-NEXT:    }
+// CHECK-NEXT:    memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT:    return


