[Mlir-commits] [mlir] 284279a - [MLIR] Add naive fusion of parallel loops.

Alexander Belyaev llvmlistbot at llvm.org
Wed Feb 19 05:51:18 PST 2020


Author: Alexander Belyaev
Date: 2020-02-19T14:51:09+01:00
New Revision: 284279ac23a944bf98911c13db2e745eb42e124b

URL: https://github.com/llvm/llvm-project/commit/284279ac23a944bf98911c13db2e745eb42e124b
DIFF: https://github.com/llvm/llvm-project/commit/284279ac23a944bf98911c13db2e745eb42e124b.diff

LOG: [MLIR] Add naive fusion of parallel loops.

Added: 
    mlir/include/mlir/Dialect/LoopOps/Passes.h
    mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
    mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
    mlir/test/Dialect/Loops/parallel-loop-fusion.mlir

Modified: 
    mlir/include/mlir/Dialect/LoopOps/LoopOps.td
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/LoopOps/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 67e9ed86d496..9462585dcf42 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -193,6 +193,7 @@ def ParallelOp : Loop_Op<"parallel",
     iterator_range<Block::args_iterator> getInductionVars() {
       return {getBody()->args_begin(), getBody()->args_end()};
     }
+    unsigned getNumLoops() { return step().size(); }
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
new file mode 100644
index 000000000000..ea143857d804
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -0,0 +1,27 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LOOPOPS_PASSES_H_
+#define MLIR_DIALECT_LOOPOPS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates a pass that fuses adjacent parallel loops (loop::ParallelOp).
+std::unique_ptr<Pass> createParallelLoopFusionPass();
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_PASSES_H_

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 6b28041bf980..6caabd74f9bb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/FxpMathOps/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
 #include "mlir/Dialect/QuantOps/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Quantizer/Transforms/Passes.h"
@@ -106,6 +107,9 @@ inline void registerAllPasses() {
   createConvertLinalgToAffineLoopsPass();
   createConvertLinalgToLLVMPass();
 
+  // LoopOps
+  createParallelLoopFusionPass();
+
   // QuantOps
   quant::createConvertSimulatedQuantPass();
   quant::createConvertConstPass();

diff  --git a/mlir/lib/Dialect/LoopOps/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
index 27f50068ec07..4e9ef44eebeb 100644
--- a/mlir/lib/Dialect/LoopOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
@@ -21,3 +21,5 @@ target_link_libraries(MLIRLoopOps
   MLIRStandardOps
   LLVMSupport
   )
+
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..b18c3656a2d9
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_llvm_library(MLIRLoopOpsTransforms
+  ParallelLoopFusion.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LoopOps
+  )
+
+target_link_libraries(MLIRLoopOpsTransforms
+  MLIRPass
+  MLIRLoopOps
+  )

diff  --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
new file mode 100644
index 000000000000..628423a65639
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -0,0 +1,182 @@
+//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Returns true if the body of `ploop` contains another ParallelOp. The walk
+/// is interrupted as soon as the first nested ParallelOp is visited.
+static bool hasNestedParallelOp(ParallelOp ploop) {
+  bool foundNested =
+      ploop.getBody()
+          ->walk([](ParallelOp) { return WalkResult::interrupt(); })
+          .wasInterrupted();
+  return foundNested;
+}
+
+/// Returns true if both loops report the same number of loops and their lower
+/// bounds, upper bounds and steps are element-wise the same values.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+                                 ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  // TODO: Extend this to support aliases and equal constants.
+  auto sameValues = [](const OperandRange &a, const OperandRange &b) -> bool {
+    return std::equal(a.begin(), a.end(), b.begin());
+  };
+
+  if (!sameValues(firstPloop.lowerBound(), secondPloop.lowerBound()))
+    return false;
+  if (!sameValues(firstPloop.upperBound(), secondPloop.upperBound()))
+    return false;
+  return sameValues(firstPloop.step(), secondPloop.step());
+}
+
+/// Returns true if `memref` has a defining operation and `ploop` is an
+/// ancestor of that operation. Values without a defining operation yield
+/// false.
+bool isDefinedInPloopBody(Value memref, ParallelOp ploop) {
+  if (auto *definingOp = memref.getDefiningOp())
+    return ploop.getOperation()->isAncestor(definingOp);
+  return false;
+}
+
+// Checks the read-after-write pattern between two parallel loops. Returns
+// `true` only if every load in `secondPloop` from a buffer that `firstPloop`
+// stores to uses the same indices as the single store to that buffer, after
+// the store indices are mapped through `firstToSecondPloopIndices`. Returns
+// `false` if a loaded memref is defined in the same block as the load, or if
+// a loaded buffer is stored to more than once in `firstPloop`.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    ParallelOp firstPloop, ParallelOp secondPloop,
+    const BlockAndValueMapping &firstToSecondPloopIndices) {
+  // Collect the index lists of all stores in firstPloop, keyed by buffer.
+  DenseMap<Value, SmallVector<ValueRange, 1>> storesPerBuffer;
+  firstPloop.getBody()->walk([&](StoreOp storeOp) {
+    storesPerBuffer[storeOp.getMemRef()].push_back(storeOp.indices());
+  });
+
+  // Returns true if this simple analysis cannot prove the load to be safe.
+  auto isUnsafeLoad = [&](LoadOp loadOp) -> bool {
+    // Stop if the memref is defined in secondPloop body. Careful alias
+    // analysis is needed.
+    auto *bufferDef = loadOp.getMemRef().getDefiningOp();
+    if (bufferDef &&
+        bufferDef->getBlock() == loadOp.getOperation()->getBlock())
+      return true;
+
+    // Buffers that firstPloop never stores to are not a concern.
+    auto storeIt = storesPerBuffer.find(loadOp.getMemRef());
+    if (storeIt == storesPerBuffer.end())
+      return false;
+
+    // Allow only single write access per buffer.
+    if (storeIt->second.size() != 1)
+      return true;
+
+    // The load indices must coincide with the mapped store indices.
+    auto storeIndices = storeIt->second.front();
+    auto loadIndices = loadOp.indices();
+    if (storeIndices.size() != loadIndices.size())
+      return true;
+    for (int idx = 0, numIndices = storeIndices.size(); idx < numIndices;
+         ++idx) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[idx]) !=
+          loadIndices[idx])
+        return true;
+    }
+    return false;
+  };
+
+  auto walkResult = secondPloop.getBody()->walk([&](LoadOp loadOp) {
+    return isUnsafeLoad(loadOp) ? WalkResult::interrupt()
+                                : WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Runs the primitive read/write pattern check in both directions: loads of
+/// secondPloop against stores of firstPloop, and then loads of firstPloop
+/// against stores of secondPloop. Succeeds only if both checks pass.
+static LogicalResult
+verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+                   const BlockAndValueMapping &firstToSecondPloopIndices) {
+  bool forwardOk = haveNoReadsAfterWriteExceptSameIndex(
+      firstPloop, secondPloop, firstToSecondPloopIndices);
+  if (!forwardOk)
+    return failure();
+
+  // The reverse check needs the inverse induction variable mapping.
+  BlockAndValueMapping inverseMapping;
+  inverseMapping.map(secondPloop.getBody()->getArguments(),
+                     firstPloop.getBody()->getArguments());
+  bool backwardOk = haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, inverseMapping);
+  return backwardOk ? success() : failure();
+}
+
+/// Returns true if the two loops can be fused: neither contains a nested
+/// ParallelOp, their iteration spaces are equal, and the dependency check
+/// succeeds. The checks run in that order and stop at the first failure.
+static bool
+isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+              const BlockAndValueMapping &firstToSecondPloopIndices) {
+  if (hasNestedParallelOp(firstPloop))
+    return false;
+  if (hasNestedParallelOp(secondPloop))
+    return false;
+  if (!equalIterationSpaces(firstPloop, secondPloop))
+    return false;
+  LogicalResult dependencies =
+      verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices);
+  return succeeded(dependencies);
+}
+
+/// Fuses firstPloop into secondPloop when that is legal: the non-terminator
+/// operations of firstPloop's body are cloned to the start of secondPloop's
+/// body with firstPloop's block arguments remapped to secondPloop's, and
+/// firstPloop is erased afterwards. Does nothing if fusion is not legal.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+                        OpBuilder b) {
+  auto *firstBody = firstPloop.getBody();
+  auto *secondBody = secondPloop.getBody();
+
+  BlockAndValueMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(firstBody->getArguments(),
+                                secondBody->getArguments());
+
+  if (isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) {
+    b.setInsertionPointToStart(secondBody);
+    for (auto &bodyOp : firstBody->without_terminator())
+      b.clone(bodyOp, firstToSecondPloopIndices);
+    firstPloop.erase();
+  }
+}
+
+/// Attempts to fuse adjacent ParallelOps located directly in the blocks of the
+/// regions of `op`.
+///
+/// For every block, the top-level ParallelOps are grouped into chains. A chain
+/// is continued as long as all operations between two ParallelOps report
+/// `hasNoSideEffect()`; otherwise a new chain is started. Within each chain,
+/// every pair of neighbouring loops is passed to `fuseIfLegal`.
+static void naivelyFuseParallelOps(Operation *op) {
+  OpBuilder b(op);
+  // Consider every single block and attempt to fuse adjacent loops.
+  for (auto &region : op->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+      // Not using `walk()` to traverse only top-level parallel loops and also
+      // make sure that there are no side-effecting ops between the parallel
+      // loops.
+      bool noSideEffects = true;
+      // Named `nestedOp` to avoid shadowing the `op` parameter.
+      for (auto &nestedOp : block.getOperations()) {
+        if (auto ploop = dyn_cast<ParallelOp>(nestedOp)) {
+          if (noSideEffects) {
+            ploopChains.back().push_back(ploop);
+          } else {
+            // A side-effecting op was seen since the last loop: start a new
+            // chain so loops are never fused across it.
+            ploopChains.push_back({ploop});
+            noSideEffects = true;
+          }
+          continue;
+        }
+        noSideEffects &= nestedOp.hasNoSideEffect();
+      }
+      for (ArrayRef<ParallelOp> ploops : ploopChains) {
+        // Fuse each loop into its successor; a fused loop is erased and never
+        // accessed again by later iterations.
+        for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
+          fuseIfLegal(ploops[i], ploops[i + 1], b);
+      }
+    }
+  }
+}
+
+namespace {
+
+/// Operation pass that runs the naive parallel loop fusion on the operation it
+/// is scheduled on.
+struct ParallelLoopFusion : public OperationPass<ParallelLoopFusion> {
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    naivelyFuseParallelOps(root);
+  }
+};
+
+} // namespace
+
+/// Constructs a new instance of the parallel loop fusion pass.
+std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
+  std::unique_ptr<Pass> fusionPass = std::make_unique<ParallelLoopFusion>();
+  return fusionPass;
+}
+
+// Registers the pass under the name `parallel-loop-fusion`.
+static PassRegistration<ParallelLoopFusion>
+    fusionPassRegistration("parallel-loop-fusion",
+                           "Fuse adjacent parallel loops.");

diff  --git a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
new file mode 100644
index 000000000000..86919228c405
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
@@ -0,0 +1,309 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-fusion)' -split-input-file | FileCheck %s --dump-input-on-failure
+
+func @fuse_empty_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @fuse_empty_loops
+// CHECK:        [[C2:%.*]] = constant 2 : index
+// CHECK:        [[C0:%.*]] = constant 0 : index
+// CHECK:        [[C1:%.*]] = constant 1 : index
+// CHECK:        loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:          "loop.terminator"() : () -> ()
+// CHECK:        }
+// CHECK-NOT:    loop.parallel
+
+// -----
+
+func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc()  : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_two
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C2:%.*]] = constant 2 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[SUM:%.*]] = alloc()
+// CHECK:      loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:        store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        "loop.terminator"() : () -> ()
+// CHECK:      }
+// CHECK:      dealloc [[SUM]]
+
+// -----
+
+func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
+                      %result: memref<100x10xf32>) {
+  %c100 = constant 100 : index
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %broadcast_rhs = alloc() : memref<100x10xf32>
+  %diff = alloc() : memref<100x10xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %rhs_elem = load %rhs[%i] : memref<100xf32>
+    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
+    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
+    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
+    %exp_elem = exp %diff_elem : f32
+    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %broadcast_rhs : memref<100x10xf32>
+  dealloc %diff : memref<100x10xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_three
+// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
+// CHECK-SAME:  [[RESULT:%.*]]: memref<100x10xf32>) {
+// CHECK:      [[C100:%.*]] = constant 100 : index
+// CHECK:      [[C10:%.*]] = constant 10 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[BROADCAST_RHS:%.*]] = alloc()
+// CHECK:      [[DIFF:%.*]] = alloc()
+// CHECK:      loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
+// CHECK:        store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]
+// CHECK:        [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
+// CHECK:        store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
+// CHECK:        store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        "loop.terminator"() : () -> ()
+// CHECK:      }
+// CHECK:      dealloc [[BROADCAST_RHS]]
+// CHECK:      dealloc [[DIFF]]
+
+// -----
+
+func @do_not_fuse_nested_ploop1() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop1
+// CHECK:        loop.parallel
+// CHECK:          loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_nested_ploop2() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop2
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+// CHECK:          loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_num_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  %buffer  = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_iteration_space() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_write_read_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
+    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %common_buf = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %sum_elem = load %common_buf[%k, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %common_buf : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_read_write_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %sum_elem = load %sum[%k, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %buffer  = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %A = subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
+      : memref<2x2xf32> to memref<?x?xf32, offset: ?, strides:[?, ?]>
+    %A_elem = load %A[%i, %j] : memref<?x?xf32, offset: ?, strides:[?, ?]>
+    "loop.terminator"() : () -> ()
+  }
+  "xla_lhlo.terminator"() : () -> ()
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index eda365215cb1..fad641cece2d 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@ target_link_libraries(MLIRMlirOptMain
 )
 
 set(LIBS
+  MLIRLoopOpsTransforms
   MLIRLoopAnalysis
   MLIRAnalysis
   MLIRAffineOps


        


More information about the Mlir-commits mailing list