[Mlir-commits] [mlir] efc0ba0 - [MLIR][Transform] Introduce loop.coalesce transform op.
Prabhdeep Singh Soni
llvmlistbot at llvm.org
Tue Jan 17 06:41:45 PST 2023
Author: Amy Wang
Date: 2023-01-17T09:38:47-05:00
New Revision: efc0ba0275bd39fe0fb7548345139e9183fce20d
URL: https://github.com/llvm/llvm-project/commit/efc0ba0275bd39fe0fb7548345139e9183fce20d
DIFF: https://github.com/llvm/llvm-project/commit/efc0ba0275bd39fe0fb7548345139e9183fce20d.diff
LOG: [MLIR][Transform] Introduce loop.coalesce transform op.
This patch slightly refactors the `walkLoop` template method of
LoopCoalescing.cpp, moves it into Affine's LoopUtils.h, and renames it to
`coalescePerfectlyNestedLoops`. This minor change allows the method to be
invoked both by the original LoopCoalescing pass and by the newly introduced
loop.coalesce transform op.
The loop.coalesce transform op can coalesce both affine and scf loop nests,
leveraging the existing LoopCoalescing mechanism. I have defined it in
SCFTransformOps.td rather than AffineTransformOps.td because it is similar
in spirit to the loop.unroll op, which handles both scf and affine loops.
Please let me know if you feel that this op should be moved into
AffineTransformOps.td instead.
The added test cases show the loop.coalesce transform op working on scf and
affine loops (inner and outer). They also show that a coalesced loop can be
further unrolled (demonstrating composability).
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D141202
Added:
mlir/test/Dialect/SCF/transform-op-coalesce.mlir
mlir/test/Dialect/SCF/transform-ops-invalid.mlir
Modified:
mlir/include/mlir/Dialect/Affine/LoopUtils.h
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/transform-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 2e3a5876ee62c..f598625e565b1 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -18,6 +18,7 @@
#include "mlir/IR/Block.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
#include <optional>
namespace mlir {
@@ -293,6 +294,54 @@ LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> nest,
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
+/// Coalesce bands of the perfect scf.for/affine.for nest rooted at `op`.
+template <typename LoopOpTy>
+LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
+  LogicalResult result(failure()); // success() once any band is coalesced.
+  SmallVector<LoopOpTy> loops;
+  getPerfectlyNestedLoops(loops, op);
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+  // 1. For each loop, find above which parent loop its operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+  }
+
+  // 2. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band. Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+      assert(maxPos == start &&
+             "expected loop bounds to be known at the start of the band");
+      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+      if (succeeded(coalesceLoops(band)))
+        result = success();
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+  return result;
+}
+
} // namespace mlir
#endif // MLIR_DIALECT_AFFINE_LOOPUTILS_H
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index dd7da91d8eb79..affa9abac31a0 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -189,4 +189,31 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
}];
}
+def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
+    FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+    TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Coalesces the perfect loop nest enclosed by a given loop";
+  let description = [{
+    Given a perfect loop nest identified by the outermost loop, perform loop
+    coalescing on its bands in a bottom-up, one-by-one manner.
+
+    #### Return modes
+    The returned handle always points to the given input loop, which is the
+    coalesced loop if the band starting at it was coalesced. Emits a
+    silenceable failure if no band of the nest could be coalesced.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type($target, $transformed)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
#endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index a7fada76ac5ba..4d4baa9ae221d 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -121,7 +121,7 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
/// Replace a perfect nest of "for" loops with a single linearized loop. Assumes
/// `loops` contains a list of perfectly nested loops with bounds and steps
/// independent of any loop induction variable involved in the nest.
-void coalesceLoops(MutableArrayRef<scf::ForOp> loops);
+LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
/// Take the ParallelLoop and for each set of dimension indices, combine them
/// into a single dimension. combinedDimensions must contain each index into
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index c8c82400bac82..13092701f376e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -32,72 +32,13 @@ namespace {
struct LoopCoalescingPass
: public impl::LoopCoalescingBase<LoopCoalescingPass> {
- /// Walk either an scf.for or an affine.for to find a band to coalesce.
- template <typename LoopOpTy>
- static void walkLoop(LoopOpTy op) {
- // Ignore nested loops.
- if (op->template getParentOfType<LoopOpTy>())
- return;
-
- SmallVector<LoopOpTy, 4> loops;
- getPerfectlyNestedLoops(loops, op);
- LLVM_DEBUG(llvm::dbgs()
- << "found a perfect nest of depth " << loops.size() << '\n');
-
- // Look for a band of loops that can be coalesced, i.e. perfectly nested
- // loops with bounds defined above some loop.
- // 1. For each loop, find above which parent loop its operands are
- // defined.
- SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
- for (unsigned i = 0, e = loops.size(); i < e; ++i) {
- operandsDefinedAbove[i] = i;
- for (unsigned j = 0; j < i; ++j) {
- if (areValuesDefinedAbove(loops[i].getOperands(),
- loops[j].getRegion())) {
- operandsDefinedAbove[i] = j;
- break;
- }
- }
- LLVM_DEBUG(llvm::dbgs()
- << " bounds of loop " << i << " are known above depth "
- << operandsDefinedAbove[i] << '\n');
- }
-
- // 2. Identify bands of loops such that the operands of all of them are
- // defined above the first loop in the band. Traverse the nest bottom-up
- // so that modifications don't invalidate the inner loops.
- for (unsigned end = loops.size(); end > 0; --end) {
- unsigned start = 0;
- for (; start < end - 1; ++start) {
- auto maxPos =
- *std::max_element(std::next(operandsDefinedAbove.begin(), start),
- std::next(operandsDefinedAbove.begin(), end));
- if (maxPos > start)
- continue;
-
- assert(maxPos == start &&
- "expected loop bounds to be known at the start of the band");
- LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start
- << " to " << end << '\n');
-
- auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
- (void)coalesceLoops(band);
- break;
- }
- // If a band was found and transformed, keep looking at the loops above
- // the outermost transformed loop.
- if (start != end - 1)
- end = start + 1;
- }
- }
-
void runOnOperation() override {
func::FuncOp func = getOperation();
- func.walk([&](Operation *op) {
+ func.walk<WalkOrder::PreOrder>([](Operation *op) { // Parents before nested ops.
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
- walkLoop(scfForOp);
+ (void)coalescePerfectlyNestedLoops(scfForOp); // Failure: no band coalesced.
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
- walkLoop(affineForOp);
+ (void)coalescePerfectlyNestedLoops(affineForOp); // Failure: no band coalesced.
});
}
};
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 5860086e588bc..ec85e566e778b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -25,7 +25,6 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 1ee5a02c8f4a6..5477af7ab52e6 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -219,9 +219,32 @@ transform::LoopUnrollOp::applyToOne(Operation *op,
result = loopUnrollByFactor(affineFor, getFactor());
if (failed(result)) {
- Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note);
- diag << "Op failed to unroll";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to unroll";
+ return diag;
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopCoalesceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::LoopCoalesceOp::applyToOne(Operation *op,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ LogicalResult result(failure()); // Stays failure() for non-loop targets.
+ if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
+ result = coalescePerfectlyNestedLoops(scfForOp);
+ else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
+ result = coalescePerfectlyNestedLoops(affineForOp);
+ // Always forward the input loop handle, even if nothing was coalesced.
+ results.push_back(op);
+ if (failed(result)) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to coalesce";
+ return diag;
}
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index b4c60b69849a0..6eca0ef9f69cf 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -656,9 +656,9 @@ static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
loop.setStep(loopPieces.step);
}
-void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
if (loops.size() < 2)
- return;
+ return failure();
scf::ForOp innermost = loops.back();
scf::ForOp outermost = loops.front();
@@ -710,6 +710,7 @@ void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
Block::iterator(second.getOperation()),
innermost.getBody()->getOperations());
second.erase();
+ return success();
}
void mlir::collapseParallelLoops(
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
new file mode 100644
index 0000000000000..4c84f62f9802b
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @coalesce_inner() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ // CHECK: scf.for %[[IV0:.+]]
+ // CHECK: scf.for %[[IV1:.+]]
+ // CHECK: scf.for %[[IV2:.+]]
+ // CHECK-NOT: scf.for %[[IV3:.+]]
+ scf.for %i = %c0 to %c10 step %c1 {
+ scf.for %j = %c0 to %c10 step %c1 {
+ scf.for %k = %i to %j step %c1 {
+ // Bounds of %k and %l are defined above %k, so %l is coalesced into %k.
+ scf.for %l = %i to %j step %c1 {
+ arith.addi %i, %j : index
+ }
+ } {coalesce}
+ }
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1
+ %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+}
+
+// -----
+
+func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+ // CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] {
+ // CHECK-NOT: affine.for %[[IV2:.+]]
+ affine.for %arg4 = 0 to 64 { // Constant bounds: %arg4 and %arg5 form one band.
+ affine.for %arg5 = 0 to 64 {
+ // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}]
+ // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}]
+ // CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
+ %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
+ %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
+ %2 = arith.addf %0, %1 : f32
+ affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1>
+ }
+ } {coalesce}
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1
+ %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"affine.for">) -> (!transform.op<"affine.for">)
+}
+
+// -----
+
+func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+ // CHECK: scf.for %[[IV1:.+]] =
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c64 = arith.constant 64 : index
+
+ scf.for %arg4 = %c0 to %c64 step %c1 {
+ // CHECK-NOT: scf.for
+ scf.for %arg5 = %c0 to %c64 step %c1 {
+ // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]]
+ // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]]
+ // CHECK-NEXT: %{{.+}} = memref.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
+ %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
+ %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
+ %2 = arith.addf %0, %1 : f32
+ // CHECK: memref.store
+ // CHECK: memref.store
+ // CHECK: memref.store
+ // 64 * 64 = 4096 = 3 * 1365 + 1, so the residual loop has a single store.
+ // CHECK: memref.store
+ memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1>
+ }
+ } {coalesce}
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1
+ %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for">
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
new file mode 100644
index 0000000000000..57812dee7eed9
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+
+#map0 = affine_map<(d0) -> (d0 * 110)>
+#map1 = affine_map<(d0) -> (696, d0 * 110 + 110)>
+func.func @test_loops_do_not_get_coalesced() {
+ affine.for %i = 0 to 7 {
+ affine.for %j = #map0(%i) to min #map1(%i) { // Bounds use %i: no band of 2+ loops.
+ }
+ } {coalesce}
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1
+ %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for">
+ // expected-error @below {{failed to coalesce}}
+ %2 = transform.loop.coalesce %1: (!transform.op<"affine.for">) -> (!transform.op<"affine.for">)
+}
+
+// -----
+
+func.func @test_loops_do_not_get_unrolled() {
+ affine.for %i = 0 to 7 { // Trip count (7) is smaller than the unroll factor (8) below.
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+ %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+ // expected-error @below {{failed to unroll}}
+ transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func private @cond() -> i1
+func.func private @body()
+
+func.func @loop_outline_op_multi_region() {
+ // expected-note @below {{target op}}
+ scf.while : () -> () { // Two regions: the condition region and the `do` body.
+ %0 = func.call @cond() : () -> i1
+ scf.condition(%0)
+ } do {
+ ^bb0:
+ func.call @body() : () -> ()
+ scf.yield
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["scf.while"]} in %arg1
+ // expected-error @below {{failed to outline}}
+ transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index d6ff2f2b821dc..0e4b3846369d3 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -84,31 +84,6 @@ transform.sequence failures(propagate) {
// -----
-func.func private @cond() -> i1
-func.func private @body()
-
-func.func @loop_outline_op_multi_region() {
- // expected-note @below {{target op}}
- scf.while : () -> () {
- %0 = func.call @cond() : () -> i1
- scf.condition(%0)
- } do {
- ^bb0:
- func.call @body() : () -> ()
- scf.yield
- }
- return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["scf.while"]} in %arg1
- // expected-error @below {{failed to outline}}
- transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
-}
-
-// -----
-
// CHECK-LABEL: @loop_peel_op
func.func @loop_peel_op() {
// CHECK: %[[C0:.+]] = arith.constant 0
More information about the Mlir-commits
mailing list