[Mlir-commits] [mlir] c3eb297 - [mlir][scf] Considering defining operators of indices when fusing scf::ParallelOp (#80145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 1 05:57:36 PST 2024


Author: Hsiangkai Wang
Date: 2024-02-01T13:57:31Z
New Revision: c3eb2978a60b4e2e0cf9c8a8f9c51b48bd49477a

URL: https://github.com/llvm/llvm-project/commit/c3eb2978a60b4e2e0cf9c8a8f9c51b48bd49477a
DIFF: https://github.com/llvm/llvm-project/commit/c3eb2978a60b4e2e0cf9c8a8f9c51b48bd49477a.diff

LOG: [mlir][scf] Considering defining operators of indices when fusing scf::ParallelOp (#80145)

When checking whether the load indices of the second loop coincide with
the store indices of the first loop, it only considers whether the index
values are identical. However, in some cases the index values are defined
by other operations. In these cases, it treats them as different even when
the results of the defining operations are the same.

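We already check that the iteration spaces are the same in isFusionLegal().
When comparing the operands of the defining operations, we therefore only
need to check that the operands come from the same induction variables. If
so, we know the results of the defining operations are the same. The
defining operations must also be free of memory effects, so that identical
operands are guaranteed to produce identical results.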

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
    mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 8f2ab5f5e6dc1..d3dca1427e517 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
@@ -102,8 +103,30 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i])
-        return WalkResult::interrupt();
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
     }
     return WalkResult::advance();
   });

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 110168ba6eca5..9c136bb635658 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -480,3 +480,98 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
 // CHECK:        scf.reduce
 // CHECK:      }
 // CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @fuse_same_indices_by_affine_apply(  // Expect fusion: the stored and loaded column indices are distinct SSA values, but both come from equivalent affine.apply ops over each loop's own IVs.
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc()  : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)  // Store-side column index, computed from the first loop's IVs.
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)  // Same map over the second loop's IVs; equivalent to the store-side op once the first loop's IVs are mapped to these.
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK:      #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fuse_same_indices_by_affine_apply
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT:  scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:    %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:    memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:    %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:    %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
+// CHECK-NEXT:    %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
+// CHECK-NEXT:    memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:    scf.reduce
+// CHECK-NEXT:  }
+// CHECK-NEXT:  memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT:  return
+
+// -----
+
+func.func @do_not_fuse_affine_apply_to_non_ind_var(  // Expect no fusion: the column-index computations use different function arguments (%OffsetA vs %OffsetB), which are not induction variables and cannot be proven equal.
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %OffsetA: index, %OffsetB: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc()  : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetA)  // Second operand is a function argument, not an induction variable.
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetB)  // %OffsetB is a different SSA value than %OffsetA, so this op is not equivalent to the store-side one.
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT:    scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:      %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
+// CHECK-NEXT:      memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:      scf.reduce
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:      %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
+// CHECK-NEXT:      %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
+// CHECK-NEXT:      %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
+// CHECK-NEXT:      memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:      scf.reduce
+// CHECK-NEXT:    }
+// CHECK-NEXT:    memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT:    return


        


More information about the Mlir-commits mailing list