[Mlir-commits] [mlir] 284279a - [MLIR] Add naive fusion of parallel loops.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Feb 19 05:51:18 PST 2020
Author: Alexander Belyaev
Date: 2020-02-19T14:51:09+01:00
New Revision: 284279ac23a944bf98911c13db2e745eb42e124b
URL: https://github.com/llvm/llvm-project/commit/284279ac23a944bf98911c13db2e745eb42e124b
DIFF: https://github.com/llvm/llvm-project/commit/284279ac23a944bf98911c13db2e745eb42e124b.diff
LOG: [MLIR] Add naive fusion of parallel loops.
Added:
mlir/include/mlir/Dialect/LoopOps/Passes.h
mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
Modified:
mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/LoopOps/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 67e9ed86d496..9462585dcf42 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -193,6 +193,7 @@ def ParallelOp : Loop_Op<"parallel",
iterator_range<Block::args_iterator> getInductionVars() {
return {getBody()->args_begin(), getBody()->args_end()};
}
+ unsigned getNumLoops() { return step().size(); }
}];
}
diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
new file mode 100644
index 000000000000..ea143857d804
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -0,0 +1,27 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LOOPOPS_PASSES_H_
+#define MLIR_DIALECT_LOOPOPS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates a loop fusion pass which fuses parallel loops.
+/// The returned pass is owned by the caller through the `std::unique_ptr`.
+std::unique_ptr<Pass> createParallelLoopFusionPass();
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_PASSES_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 6b28041bf980..6caabd74f9bb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -27,6 +27,7 @@
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/QuantOps/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Quantizer/Transforms/Passes.h"
@@ -106,6 +107,9 @@ inline void registerAllPasses() {
createConvertLinalgToAffineLoopsPass();
createConvertLinalgToLLVMPass();
+ // LoopOps
+ createParallelLoopFusionPass();
+
// QuantOps
quant::createConvertSimulatedQuantPass();
quant::createConvertConstPass();
diff --git a/mlir/lib/Dialect/LoopOps/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
index 27f50068ec07..4e9ef44eebeb 100644
--- a/mlir/lib/Dialect/LoopOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
@@ -21,3 +21,5 @@ target_link_libraries(MLIRLoopOps
MLIRStandardOps
LLVMSupport
)
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..b18c3656a2d9
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_llvm_library(MLIRLoopOpsTransforms
+ ParallelLoopFusion.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LoopOps
+ )
+
+target_link_libraries(MLIRLoopOpsTransforms
+ MLIRPass
+ MLIRLoopOps
+ )
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
new file mode 100644
index 000000000000..628423a65639
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -0,0 +1,182 @@
+//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Returns true if another ParallelOp is found inside the body of `ploop`.
+static bool hasNestedParallelOp(ParallelOp ploop) {
+  // The callback fires only for ParallelOps, so interrupting on the first
+  // visit is enough to know that a nested loop exists.
+  return ploop.getBody()
+      ->walk([](ParallelOp) { return WalkResult::interrupt(); })
+      .wasInterrupted();
+}
+
+/// Returns true if both loops iterate over the same space: they have the same
+/// number of loops and their lower bounds, upper bounds and steps are the very
+/// same SSA values, position by position.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+                                 ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  // Use the four-iterator form of std::equal so that operand lists of
+  // different lengths compare unequal instead of reading past the end of the
+  // shorter one.
+  auto matchOperands = [](const OperandRange &lhs,
+                          const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
+  };
+  return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
+         matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
+         matchOperands(firstPloop.step(), secondPloop.step());
+}
+
+/// Returns true when `memref` is produced by an operation for which `ploop`
+/// is an ancestor. Values that have no defining operation yield false.
+bool isDefinedInPloopBody(Value memref, ParallelOp ploop) {
+  if (Operation *definingOp = memref.getDefiningOp())
+    return ploop.getOperation()->isAncestor(definingOp);
+  return false;
+}
+
+// Checks whether the loads in `secondPloop` are compatible with the stores in
+// `firstPloop`. Returns `true` when every load in `secondPloop` that reads a
+// buffer stored to by `firstPloop` uses exactly the indices of that (single)
+// store, after translating the store indices through
+// `firstToSecondPloopIndices`. Returns `false` as soon as one load violates
+// this, or loads from a memref defined in the same block as the load itself.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    ParallelOp firstPloop, ParallelOp secondPloop,
+    const BlockAndValueMapping &firstToSecondPloopIndices) {
+  // Gather, per memref, the index list of every store found in firstPloop.
+  DenseMap<Value, SmallVector<ValueRange, 1>> storesPerBuffer;
+  firstPloop.getBody()->walk([&](StoreOp storeOp) {
+    storesPerBuffer[storeOp.getMemRef()].push_back(storeOp.indices());
+  });
+
+  // Returns true if `loadOp` prevents fusion.
+  auto blocksFusion = [&](LoadOp loadOp) -> bool {
+    // Stop if the memref is defined in secondPloop body. Careful alias
+    // analysis is needed.
+    Operation *bufferDef = loadOp.getMemRef().getDefiningOp();
+    if (bufferDef &&
+        bufferDef->getBlock() == loadOp.getOperation()->getBlock())
+      return true;
+
+    // Buffers that firstPloop never stores to cannot conflict.
+    auto storeIt = storesPerBuffer.find(loadOp.getMemRef());
+    if (storeIt == storesPerBuffer.end())
+      return false;
+
+    // Allow only single write access per buffer.
+    if (storeIt->second.size() != 1)
+      return true;
+
+    // The load indices of secondPloop must coincide with the store indices of
+    // firstPloop for the same memref.
+    auto storeIndices = storeIt->second.front();
+    auto loadIndices = loadOp.indices();
+    if (storeIndices.size() != loadIndices.size())
+      return true;
+    for (int pos = 0, numIndices = storeIndices.size(); pos < numIndices;
+         ++pos) {
+      Value mappedStoreIndex =
+          firstToSecondPloopIndices.lookupOrDefault(storeIndices[pos]);
+      if (mappedStoreIndex != loadIndices[pos])
+        return true;
+    }
+    return false;
+  };
+
+  WalkResult outcome = secondPloop.getBody()->walk([&](LoadOp loadOp) {
+    return blocksFusion(loadOp) ? WalkResult::interrupt()
+                                : WalkResult::advance();
+  });
+  return !outcome.wasInterrupted();
+}
+
+/// Performs a very primitive dependence check between the two loops by
+/// running the simple load/store pattern check in both directions. Succeeds
+/// only if both directions pass.
+static LogicalResult
+verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+                   const BlockAndValueMapping &firstToSecondPloopIndices) {
+  // Direction 1: stores of the first loop against loads of the second loop.
+  const bool forwardOk = haveNoReadsAfterWriteExceptSameIndex(
+      firstPloop, secondPloop, firstToSecondPloopIndices);
+  if (!forwardOk)
+    return failure();
+
+  // Direction 2: stores of the second loop against loads of the first loop.
+  // This needs the mapping from the second loop's induction variables to the
+  // first loop's.
+  BlockAndValueMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+                                firstPloop.getBody()->getArguments());
+  if (!haveNoReadsAfterWriteExceptSameIndex(secondPloop, firstPloop,
+                                            secondToFirstPloopIndices))
+    return failure();
+  return success();
+}
+
+/// Returns true if `firstPloop` and `secondPloop` may be fused: neither loop
+/// contains a nested ParallelOp, both have equal iteration spaces, and the
+/// dependence check succeeds.
+static bool
+isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+              const BlockAndValueMapping &firstToSecondPloopIndices) {
+  // Loops that contain other parallel loops are never fused.
+  if (hasNestedParallelOp(firstPloop) || hasNestedParallelOp(secondPloop))
+    return false;
+  if (!equalIterationSpaces(firstPloop, secondPloop))
+    return false;
+  return succeeded(
+      verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices));
+}
+
+/// Fuses `firstPloop` into `secondPloop` when that is legal: the operations
+/// of firstPloop's body (without the terminator) are cloned to the start of
+/// secondPloop's body, and firstPloop is erased. Does nothing otherwise.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+                        OpBuilder b) {
+  Block *firstBody = firstPloop.getBody();
+  Block *secondBody = secondPloop.getBody();
+
+  // Map the induction variables of the first loop onto those of the second so
+  // that cloned operations refer to the surviving loop's block arguments.
+  BlockAndValueMapping inductionVarMap;
+  inductionVarMap.map(firstBody->getArguments(), secondBody->getArguments());
+
+  if (isFusionLegal(firstPloop, secondPloop, inductionVarMap)) {
+    b.setInsertionPointToStart(secondBody);
+    for (Operation &bodyOp : firstBody->without_terminator())
+      b.clone(bodyOp, inductionVarMap);
+    firstPloop.erase();
+  }
+}
+
+/// Attempts to fuse neighbouring ParallelOps found directly inside the blocks
+/// of `op`'s regions. Within each block the loops are grouped into chains; a
+/// new chain starts whenever an operation that is not known to be
+/// side-effect free appears between two loops. Fusion is then tried for each
+/// pair of consecutive loops within a chain.
+static void naivelyFuseParallelOps(Operation *op) {
+  OpBuilder b(op);
+  // Consider every single block and attempt to fuse adjacent loops.
+  for (auto &region : op->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      // Start with a single empty chain so that `back()` is always valid.
+      SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+      // Not using `walk()` to traverse only top-level parallel loops and also
+      // make sure that there are no side-effecting ops between the parallel
+      // loops.
+      bool noSideEffects = true;
+      for (auto &nestedOp : block.getOperations()) {
+        if (auto ploop = dyn_cast<ParallelOp>(nestedOp)) {
+          if (noSideEffects) {
+            ploopChains.back().push_back(ploop);
+          } else {
+            // A side-effecting op was seen since the previous loop: begin a
+            // fresh chain and reset the flag.
+            ploopChains.push_back({ploop});
+            noSideEffects = true;
+          }
+          continue;
+        }
+        noSideEffects &= nestedOp.hasNoSideEffect();
+      }
+      for (ArrayRef<ParallelOp> ploops : ploopChains) {
+        for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
+          fuseIfLegal(ploops[i], ploops[i + 1], b);
+      }
+    }
+  }
+}
+
+namespace {
+
+/// Operation pass that runs the naive parallel loop fusion on the operation
+/// it is scheduled on.
+struct ParallelLoopFusion : public OperationPass<ParallelLoopFusion> {
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    naivelyFuseParallelOps(root);
+  }
+};
+
+} // namespace
+
+/// Factory for the parallel loop fusion pass; returns a newly allocated
+/// `ParallelLoopFusion` instance.
+std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
+ return std::make_unique<ParallelLoopFusion>();
+}
+
+// Registers the pass under the name `parallel-loop-fusion` together with its
+// one-line description.
+static PassRegistration<ParallelLoopFusion>
+ pass("parallel-loop-fusion", "Fuse adjacent parallel loops.");
diff --git a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
new file mode 100644
index 000000000000..86919228c405
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
@@ -0,0 +1,309 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-fusion)' -split-input-file | FileCheck %s --dump-input-on-failure
+
+// Two adjacent empty loops over the same iteration space; a single loop is
+// expected after the pass.
+func @fuse_empty_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_empty_loops
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK-NOT: loop.parallel
+
+// -----
+
+// The first loop stores to %sum[%i, %j] and the second loop loads %sum at the
+// same indices, so the CHECK lines expect both bodies inside a single loop.
+func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+ %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %sum = alloc() : memref<2x2xf32>
+ // Producer loop: %sum = %B + %C, element-wise.
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = load %B[%i, %j] : memref<2x2xf32>
+ %C_elem = load %C[%i, %j] : memref<2x2xf32>
+ %sum_elem = addf %B_elem, %C_elem : f32
+ store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ // Consumer loop: %result = %sum * %A, element-wise.
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = mulf %sum_elem, %A_elem : f32
+ store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_two
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[SUM:%.*]] = alloc()
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK: store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK: store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: dealloc [[SUM]]
+
+// -----
+
+// Three adjacent loops over the same iteration space, each consuming the
+// buffer written by the previous one at identical indices; the CHECK lines
+// expect a single fused loop.
+func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
+                 %result: memref<100x10xf32>) {
+  %c100 = constant 100 : index
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %broadcast_rhs = alloc() : memref<100x10xf32>
+  %diff = alloc() : memref<100x10xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %rhs_elem = load %rhs[%i] : memref<100xf32>
+    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
+    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
+    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
+    %exp_elem = exp %diff_elem : f32
+    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %broadcast_rhs : memref<100x10xf32>
+  dealloc %diff : memref<100x10xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_three
+// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
+// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
+// CHECK: [[C100:%.*]] = constant 100 : index
+// CHECK: [[C10:%.*]] = constant 10 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[BROADCAST_RHS:%.*]] = alloc()
+// CHECK: [[DIFF:%.*]] = alloc()
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
+// CHECK: [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
+// CHECK: store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK: [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
+// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]
+// CHECK: [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
+// CHECK: store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK: [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK: [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
+// CHECK: store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: dealloc [[BROADCAST_RHS]]
+// CHECK: dealloc [[DIFF]]
+
+// -----
+
+// The first loop contains a nested parallel loop; the CHECK lines expect all
+// three loops to remain.
+func @do_not_fuse_nested_ploop1() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop1
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// The second loop contains a nested parallel loop; the CHECK lines expect all
+// three loops to remain.
+func @do_not_fuse_nested_ploop2() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop2
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// A two-dimensional loop followed by a one-dimensional loop; the CHECK lines
+// expect both loops to remain.
+func @do_not_fuse_loops_unmatching_num_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// An `alloc` sits between the two loops; the CHECK lines expect both loops to
+// remain.
+func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  %buffer = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// The loops use different upper bounds and steps; the CHECK lines expect both
+// loops to remain.
+func @do_not_fuse_loops_unmatching_iteration_space() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// The first loop stores to %common_buf[%i, %j] while the second loop loads
+// %common_buf[%i + 1, %j]; the CHECK lines expect both loops to remain.
+func @do_not_fuse_unmatching_write_read_patterns(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>,
+ %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %common_buf = alloc() : memref<2x2xf32>
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = load %B[%i, %j] : memref<2x2xf32>
+ %C_elem = load %C[%i, %j] : memref<2x2xf32>
+ %sum_elem = addf %B_elem, %C_elem : f32
+ store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ // Shifted row index: this load does not use the indices of the store above.
+ %k = addi %i, %c1 : index
+ %sum_elem = load %common_buf[%k, %j] : memref<2x2xf32>
+ %A_elem = load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = mulf %sum_elem, %A_elem : f32
+ store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ dealloc %common_buf : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// The first loop loads %common_buf[%i, %j] while the second loop stores to
+// %common_buf[%j, %i]; the CHECK lines expect both loops to remain.
+func @do_not_fuse_unmatching_read_write_patterns(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %sum = alloc() : memref<2x2xf32>
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = load %B[%i, %j] : memref<2x2xf32>
+ %C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
+ %sum_elem = addf %B_elem, %C_elem : f32
+ store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ // Note: %sum is also read at a shifted row index here, not at the indices
+ // the first loop stored to.
+ %k = addi %i, %c1 : index
+ %sum_elem = load %sum[%k, %j] : memref<2x2xf32>
+ %A_elem = load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = mulf %sum_elem, %A_elem : f32
+ // Transposed indices compared to the load of %common_buf in the first loop.
+ store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
+ "loop.terminator"() : () -> ()
+ }
+ dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+// The second loop loads from a subview created inside its own body; the CHECK
+// lines expect both loops to remain.
+func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %buffer = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %A = subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
+      : memref<2x2xf32> to memref<?x?xf32, offset: ?, strides:[?, ?]>
+    %A_elem = load %A[%i, %j] : memref<?x?xf32, offset: ?, strides:[?, ?]>
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
+// CHECK: loop.parallel
+// CHECK: loop.parallel
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index eda365215cb1..fad641cece2d 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@ target_link_libraries(MLIRMlirOptMain
)
set(LIBS
+ MLIRLoopOpsTransforms
MLIRLoopAnalysis
MLIRAnalysis
MLIRAffineOps
More information about the Mlir-commits
mailing list