[Mlir-commits] [mlir] af5e83f - [MLIR] Introduce utility to hoist affine if/else conditions
Uday Bondhugula
llvmlistbot at llvm.org
Wed Apr 15 12:02:58 PDT 2020
Author: Uday Bondhugula
Date: 2020-04-16T00:32:34+05:30
New Revision: af5e83f569819bab68a070ca59128651feefb7ef
URL: https://github.com/llvm/llvm-project/commit/af5e83f569819bab68a070ca59128651feefb7ef
DIFF: https://github.com/llvm/llvm-project/commit/af5e83f569819bab68a070ca59128651feefb7ef.diff
LOG: [MLIR] Introduce utility to hoist affine if/else conditions
This revision introduces a utility to unswitch affine.for/parallel loops
by hoisting affine.if operations past surrounding affine.for/parallel.
The hoisting works for both perfect/imperfect nests and in the presence
of else blocks. Currently, the condition is hoisted to the outermost level
possible. A test pass is used to exercise the utility.
Add convenience method Operation::getParentWithTrait<Trait>.
Depends on D77487.
Differential Revision: https://reviews.llvm.org/D77870
Added:
mlir/include/mlir/Dialect/Affine/Utils.h
mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/test/Dialect/Affine/loop-unswitch.mlir
mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/lib/Dialect/Affine/CMakeLists.txt
mlir/test/lib/Dialect/Affine/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d30e6b26ff2c..cb827395840a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -337,13 +337,16 @@ def AffineIfOp : Affine_Op<"if",
/// list of AffineIf is not resizable.
void setConditional(IntegerSet set, ValueRange operands);
+ /// Returns true if an else block exists.
+ bool hasElse() { return !elseRegion().empty(); }
+
Block *getThenBlock() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
return &thenRegion().front();
}
Block *getElseBlock() {
- assert(!elseRegion().empty() && "Empty 'else' region.");
+ assert(hasElse() && "Empty 'else' region.");
return &elseRegion().front();
}
@@ -353,7 +356,7 @@ def AffineIfOp : Affine_Op<"if",
return OpBuilder(&body, std::prev(body.end()));
}
OpBuilder getElseBodyBuilder() {
- assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
+ assert(hasElse() && "No 'else' block");
Block &body = elseRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
@@ -491,6 +494,9 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
Block *getBody();
OpBuilder getBodyBuilder();
+ /// Returns the induction variables of this op, i.e., the arguments of its
+ /// body block.
+ MutableArrayRef<BlockArgument> getIVs() {
+ return getBody()->getArguments();
+ }
void setSteps(ArrayRef<int64_t> newSteps);
static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
new file mode 100644
index 000000000000..a2c0211b301e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -0,0 +1,29 @@
+//===- Utils.h - Affine dialect utilities -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a set of utilities for the affine dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_UTILS_H
+#define MLIR_DIALECT_AFFINE_UTILS_H
+
+namespace mlir {
+
+class AffineIfOp;
+struct LogicalResult;
+
+/// Hoists out affine.if/else to as high as possible, i.e., past all invariant
+/// affine.fors/parallel's (and past any enclosing affine.if ops). The nest
+/// hoisted over is versioned into a 'then' copy and an 'else' copy, so this
+/// hoisting could lead to significant code expansion in some cases.
+///
+/// Returns success if any hoisting happened, failure otherwise. If `folded` is
+/// non-null, it is set to true when canonicalization folded or erased `ifOp`
+/// (failure is returned in that case), and to false otherwise.
+LogicalResult hoistAffineIfOp(AffineIfOp ifOp, bool *folded = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index a503c1ec13f0..f2b174c067ff 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -116,6 +116,12 @@ class OpState {
return getOperation()->getParentOfType<OpTy>();
}
+  /// Forwards to Operation::getParentWithTrait: yields the nearest enclosing
+  /// operation that carries the trait `Trait`, or null when there is none.
+  template <template <typename T> class Trait>
+  Operation *getParentWithTrait() {
+    Operation *self = getOperation();
+    return self->getParentWithTrait<Trait>();
+  }
+
/// Return the context this operation belongs to.
MLIRContext *getContext() { return getOperation()->getContext(); }
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 099ca3c610c4..f0be03a5cced 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -126,6 +126,16 @@ class Operation final
return OpTy();
}
+  /// Walks the chain of proper ancestors of this operation, innermost first,
+  /// and returns the first one that has the trait `Trait`; returns null when
+  /// no ancestor has it.
+  template <template <typename T> class Trait>
+  Operation *getParentWithTrait() {
+    for (Operation *ancestor = getParentOp(); ancestor != nullptr;
+         ancestor = ancestor->getParentOp()) {
+      if (ancestor->hasTrait<Trait>())
+        return ancestor;
+    }
+    return nullptr;
+  }
+
/// Return true if this operation is a proper ancestor of the `other`
/// operation.
bool isProperAncestor(Operation *other);
diff --git a/mlir/lib/Dialect/Affine/CMakeLists.txt b/mlir/lib/Dialect/Affine/CMakeLists.txt
index c018b50f967f..95cf0a44f21b 100644
--- a/mlir/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/CMakeLists.txt
@@ -19,3 +19,4 @@ target_link_libraries(MLIRAffine
)
add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
new file mode 100644
index 000000000000..64738c0cf369
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Utilities for the affine dialect (Utils.cpp, e.g. affine.if hoisting); the
+# library links publicly against the affine dialect library MLIRAffine.
+add_mlir_dialect_library(MLIRAffineUtils
+ Utils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
+
+ )
+target_link_libraries(MLIRAffineUtils
+ PUBLIC
+ MLIRAffine
+ )
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
new file mode 100644
index 000000000000..811579bb6c8c
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -0,0 +1,175 @@
+//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous transformation utilities for the Affine
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+/// Moves every operation except the trailing terminator out of the `then`
+/// block (`elseBlock` == false) or the `else` block (`elseBlock` == true) of
+/// `ifOp`, placing them in `ifOp`'s block right before `ifOp`. Then erases
+/// `ifOp` along with whatever it still holds.
+static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
+  assert((!elseBlock || ifOp.hasElse()) && "else block expected");
+
+  Block *from = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
+  Block *into = ifOp.getOperation()->getBlock();
+  auto &fromOps = from->getOperations();
+  // The last op (the terminator) stays behind and is erased with `ifOp`.
+  into->getOperations().splice(Block::iterator(ifOp), fromOps,
+                               fromOps.begin(), std::prev(fromOps.end()));
+  ifOp.erase();
+}
+
+/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
+/// on. The `ifOp` could be hoisted and placed right before such an operation.
+/// The walk goes up through enclosing affine.if ops and through
+/// affine.for/affine.parallel ops none of whose induction variables is an
+/// operand of `ifOp`. It stops at the function, at any other kind of op, or at
+/// the first loop whose IV(s) the condition uses. Returns `ifOp` itself when no
+/// hoisting is possible. This method assumes that the ifOp has been
+/// canonicalized (to be correct and effective).
+static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
+  auto ifOperands = ifOp.getOperands();
+  // True if `iv` is used by the conditional.
+  auto usedByIf = [&](Value iv) { return llvm::is_contained(ifOperands, iv); };
+
+  // Walk up the parents past all loops that this conditional is invariant on.
+  Operation *res = ifOp.getOperation();
+  while (Operation *parentOp = res->getParentOp()) {
+    if (isa<FuncOp>(parentOp))
+      break;
+    if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
+      if (usedByIf(forOp.getInductionVar()))
+        break;
+    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
+      // The condition is not invariant on this affine.parallel if it uses any
+      // of its IVs. This must leave the enclosing `while`, not merely a loop
+      // over the IVs.
+      if (llvm::any_of(parallelOp.getIVs(), usedByIf))
+        break;
+    } else if (!isa<AffineIfOp>(parentOp)) {
+      // Won't walk up past anything other than affine.for/parallel/if ops.
+      break;
+    }
+    // Invariant on `parentOp` (affine.if parents can always be hoisted past).
+    res = parentOp;
+  }
+  return res;
+}
+
+/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
+/// `hoistOverOp`, which is expected to be `ifOp` itself or an ancestor of it.
+/// A new affine.if with the same integer set and operands is created right
+/// before `hoistOverOp`. Its 'then' block receives `hoistOverOp` with `ifOp`
+/// replaced by `ifOp`'s 'then' body. Its 'else' block receives a clone of
+/// `hoistOverOp` in which the clone of `ifOp` is replaced by its 'else' body,
+/// or simply erased if it had none. Returns the new hoisted op if any hoisting
+/// happened, otherwise the same `ifOp`.
+static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
+ // No hoisting to do.
+ if (hoistOverOp == ifOp)
+ return ifOp;
+
+ // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
+ // the else block. Then drop the else block of the original 'if' in the 'then'
+ // branch while promoting its then block, and analogously drop the 'then'
+ // block of the original 'if' from the 'else' branch while promoting its else
+ // block.
+ // `operandMap` is left empty, so the clone below keeps using values defined
+ // outside `hoistOverOp` as-is.
+ BlockAndValueMapping operandMap;
+ // `b` initially inserts right before `hoistOverOp`, so the hoisted 'if'
+ // takes the place of `hoistOverOp` in its parent block.
+ OpBuilder b(hoistOverOp);
+ auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
+ ifOp.getOperands(),
+ /*elseBlock=*/true);
+
+ // Create a clone of hoistOverOp to use for the else branch of the hoisted
+ // conditional. The else block may get optimized away if empty.
+ Operation *hoistOverOpClone = nullptr;
+ // We use this unique name to identify/find `ifOp`'s clone in the else
+ // version.
+ Identifier idForIfOp = b.getIdentifier("__mlir_if_hoisting");
+ operandMap.clear();
+ b.setInsertionPointAfter(hoistOverOp);
+ // We'll set an attribute to identify this op in a clone of this sub-tree.
+ // It must be set before cloning so that the clone of `ifOp` carries it too.
+ ifOp.setAttr(idForIfOp, b.getBoolAttr(true));
+ hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
+
+ // Promote the 'then' block of the original affine.if in the then version.
+ // This also erases `ifOp` (and with it the marker attribute on it).
+ promoteIfBlock(ifOp, /*elseBlock=*/false);
+
+ // Move the then version to the hoisted if op's 'then' block.
+ // This splice moves the single op `hoistOverOp` to the start of that block.
+ auto *thenBlock = hoistedIfOp.getThenBlock();
+ thenBlock->getOperations().splice(thenBlock->begin(),
+ hoistOverOp->getBlock()->getOperations(),
+ Block::iterator(hoistOverOp));
+
+ // Find the clone of the original affine.if op in the else version.
+ AffineIfOp ifCloneInElse;
+ hoistOverOpClone->walk([&](AffineIfOp ifClone) {
+ if (!ifClone.getAttr(idForIfOp))
+ return WalkResult::advance();
+ ifCloneInElse = ifClone;
+ return WalkResult::interrupt();
+ });
+ assert(ifCloneInElse && "if op clone should exist");
+ // For the else block, promote the else block of the original 'if' if it had
+ // one; otherwise, the op itself is to be erased.
+ // Either way the marked clone is erased, so no marker attribute survives.
+ if (!ifCloneInElse.hasElse())
+ ifCloneInElse.erase();
+ else
+ promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
+
+ // Move the else version into the else block of the hoisted if op.
+ auto *elseBlock = hoistedIfOp.getElseBlock();
+ elseBlock->getOperations().splice(
+ elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
+ Block::iterator(hoistOverOpClone));
+
+ return hoistedIfOp;
+}
+
+/// Hoists `ifOp` as high as possible past the surrounding affine.for/parallel
+/// ops that it is invariant on (and past enclosing affine.if ops). Returns
+/// success if any hoisting happened. If `folded` is non-null, it is set to true
+/// when canonicalization folded or erased `ifOp` (failure is then returned),
+/// and to false otherwise.
+LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
+  // Apply canonicalization patterns and folding - this is necessary for the
+  // hoisting check to be correct (operands should be composed), and to be more
+  // effective (no unused operands). Since the pattern rewriter's folding is
+  // entangled with application of patterns, we may fold/end up erasing the op,
+  // in which case we return with `folded` being set.
+  OwningRewritePatternList patterns;
+  AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
+  // Initialized so that it is never read indeterminate, even if the driver
+  // does not write it.
+  bool erased = false;
+  applyOpPatternsAndFold(ifOp, patterns, &erased);
+  if (folded)
+    *folded = erased;
+  if (erased)
+    return failure();
+
+  // The folding above should have ensured this, but the affine.if's
+  // canonicalization is missing composition of affine.applys into it.
+  // NOTE(review): isForInductionVar presumably rejects affine.parallel IVs, so
+  // a condition on a parallel IV would trip this assertion in debug builds even
+  // though getOutermostInvariantForOp handles affine.parallel - confirm.
+  assert(llvm::all_of(ifOp.getOperands(),
+                      [](Value v) {
+                        return isTopLevelValue(v) || isForInductionVar(v);
+                      }) &&
+         "operands not composed");
+
+  // We are going to hoist as high as possible.
+  // TODO: this could be customized in the future.
+  Operation *hoistOverOp = getOutermostInvariantForOp(ifOp);
+
+  AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
+  // Nothing to hoist over.
+  if (hoistedIfOp == ifOp)
+    return failure();
+
+  // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
+  // a sequence of affine.fors that are all perfectly nested).
+  applyPatternsAndFoldGreedily(
+      hoistedIfOp.getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
+      std::move(patterns));
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Affine/loop-unswitch.mlir b/mlir/test/Dialect/Affine/loop-unswitch.mlir
new file mode 100644
index 000000000000..801eb059511c
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-unswitch.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt %s -split-input-file -test-affine-loop-unswitch | FileCheck %s
+
+// CHECK-DAG: #[[SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)>
+
+// Imperfect nest: the condition only uses %i, so the affine.if is hoisted out
+// of the %j loop but stays inside the %i loop. The %j loop is versioned into a
+// 'then' copy (with the load of %B) and an 'else' copy (without it).
+// CHECK-LABEL: func @if_else_imperfect
+func @if_else_imperfect(%A : memref<100xi32>, %B : memref<100xi32>, %v : i32) {
+// CHECK: %[[A:.*]]: memref<100xi32>, %[[B:.*]]: memref
+ affine.for %i = 0 to 100 {
+ affine.load %A[%i] : memref<100xi32>
+ affine.for %j = 0 to 100 {
+ affine.load %A[%j] : memref<100xi32>
+ affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
+ affine.load %B[%j] : memref<100xi32>
+ }
+ call @external() : () -> ()
+ }
+ affine.load %A[%i] : memref<100xi32>
+ }
+ return
+}
+func @external()
+
+// CHECK: affine.for %[[I:.*]] = 0 to 100 {
+// CHECK-NEXT: affine.load %[[A]][%[[I]]]
+// CHECK-NEXT: affine.if #[[SET]](%[[I]]) {
+// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 100 {
+// CHECK-NEXT: affine.load %[[A]][%[[J]]]
+// CHECK-NEXT: affine.load %[[B]][%[[J]]]
+// CHECK-NEXT: call
+// CHECK-NEXT: }
+// CHECK-NEXT: } else {
+// CHECK-NEXT: affine.for %[[JJ:.*]] = 0 to 100 {
+// CHECK-NEXT: affine.load %[[A]][%[[JJ]]]
+// CHECK-NEXT: call
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.load %[[A]][%[[I]]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+// Declarations of the external functions called by the tests below; the calls
+// keep the loop/if bodies from being removed as dead code.
+func @foo()
+func @bar()
+func @abc()
+func @xyz()
+
+// Perfect nest without an else block: the condition only uses %i, so the
+// affine.if is hoisted out of the %j and %k loops. The resulting 'else' version
+// is empty and is removed, so no 'else' appears in the output.
+// CHECK-LABEL: func @if_then_perfect
+func @if_then_perfect(%A : memref<100xi32>, %v : i32) {
+ affine.for %i = 0 to 100 {
+ affine.for %j = 0 to 100 {
+ affine.for %k = 0 to 100 {
+ affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
+ affine.load %A[%i] : memref<100xi32>
+ }
+ }
+ }
+ }
+ return
+}
+// CHECK: affine.for
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NOT: else
+
+
+// Perfect nest with an else block: the condition uses %i and %j, so the
+// affine.if is hoisted out of the %k loop only. The 'then' and 'else' copies of
+// the %k loop each keep the unconditional calls to @foo and @bar, plus the body
+// of the corresponding branch.
+// CHECK-LABEL: func @if_else_perfect
+func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
+ affine.for %i = 0 to 99 {
+ affine.for %j = 0 to 100 {
+ affine.for %k = 0 to 100 {
+ call @foo() : () -> ()
+ affine.if affine_set<(d0, d1) : (d0 - 2 >= 0, -d1 + 80 >= 0)>(%i, %j) {
+ affine.load %A[%i] : memref<100xi32>
+ call @abc() : () -> ()
+ } else {
+ affine.load %A[%i + 1] : memref<100xi32>
+ call @xyz() : () -> ()
+ }
+ call @bar() : () -> ()
+ }
+ }
+ }
+ return
+}
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: call @foo
+// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}]
+// CHECK-NEXT: call @abc
+// CHECK-NEXT: call @bar
+// CHECK-NEXT: }
+// CHECK-NEXT: else
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: call @foo
+// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} + 1]
+// CHECK-NEXT: call @xyz
+// CHECK-NEXT: call @bar
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// The condition only uses the function argument %N, so the affine.if is
+// hoisted out of the %i loop to function level. The 'else' copy of the loop
+// keeps only the unconditional load.
+// CHECK-LABEL: func @if_then_imperfect
+func @if_then_imperfect(%A : memref<100xi32>, %N : index) {
+ affine.for %i = 0 to 100 {
+ affine.load %A[0] : memref<100xi32>
+ affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%N) {
+ affine.load %A[%i] : memref<100xi32>
+ }
+ }
+ return
+}
+// CHECK: affine.if
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: }
+// CHECK-NEXT: } else {
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// Verifies that unused operands are dropped by canonicalization. The inner
+// condition never uses d1 (%j), so once %j is dropped the inner affine.if can
+// be hoisted past the outer affine.if and the %j loop (but not the %i loop).
+// CHECK-LABEL: func @hoist_after_canonicalize
+func @hoist_after_canonicalize() {
+ affine.for %i = 0 to 100 {
+ affine.for %j = 0 to 100 {
+ affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%j) {
+ affine.if affine_set<(d0, d1) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%i, %j) {
+ // The call to external is to avoid DCE on affine.if.
+ call @foo() : () -> ()
+ }
+ }
+ }
+ }
+ return
+}
+// CHECK: affine.for
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: call
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// The affine.if has an empty body, so it is erased during canonicalization
+// instead of being hoisted; only the (now empty) loop remains.
+// CHECK-LABEL: func @handle_dead_if
+func @handle_dead_if(%N : index) {
+ affine.for %i = 0 to 100 {
+ affine.if affine_set<(d0) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%N) {
+ }
+ }
+ return
+}
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+// A test case with affine.parallel. The condition uses %arg7 (the outermost
+// loop IV) and the function argument %arg6, but neither the affine.parallel
+// IVs (%i0, %j0) nor %arg8. The affine.if is therefore hoisted out of the
+// %arg8 loop and the affine.parallel, and stays inside the %arg7 loop.
+
+#flb1 = affine_map<(d0) -> (d0 * 3)>
+#fub1 = affine_map<(d0) -> (d0 * 3 + 3)>
+#flb0 = affine_map<(d0) -> (d0 * 16)>
+#fub0 = affine_map<(d0) -> (d0 * 16 + 16)>
+#pub1 = affine_map<(d0)[s0] -> (s0, d0 * 3 + 3)>
+#pub0 = affine_map<(d0)[s0] -> (s0, d0 * 16 + 16)>
+#lb1 = affine_map<(d0) -> (d0 * 480)>
+#ub1 = affine_map<(d0)[s0] -> (s0, d0 * 480 + 480)>
+#lb0 = affine_map<(d0) -> (d0 * 110)>
+#ub0 = affine_map<(d0)[s0] -> (d0 * 110 + 110, s0 floordiv 3)>
+
+#set0 = affine_set<(d0, d1)[s0, s1] : (d0 * -16 + s0 - 16 >= 0, d1 * -3 + s1 - 3 >= 0)>
+
+// CHECK-LABEL: func @perfect_if_else
+func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 : index,
+ %arg5 : index, %arg6 : index, %sym : index) {
+ affine.for %arg7 = #lb0(%arg5) to min #ub0(%arg5)[%sym] {
+ affine.parallel (%i0, %j0) = (0, 0) to (symbol(%sym), 100) step (10, 10) {
+ affine.for %arg8 = #lb1(%arg4) to min #ub1(%arg4)[%sym] {
+ affine.if #set0(%arg6, %arg7)[%sym, %sym] {
+ affine.for %arg9 = #flb0(%arg6) to #fub0(%arg6) {
+ affine.for %arg10 = #flb1(%arg7) to #fub1(%arg7) {
+ affine.load %arg0[0, 0] : memref<?x?xf64>
+ }
+ }
+ } else {
+ affine.for %arg9 = #lb0(%arg6) to min #pub0(%arg6)[%sym] {
+ affine.for %arg10 = #lb1(%arg7) to min #pub1(%arg7)[%sym] {
+ affine.load %arg0[0, 0] : memref<?x?xf64>
+ }
+ }
+ }
+ }
+ }
+ }
+ return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: affine.parallel
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: } else {
+// CHECK-NEXT: affine.parallel
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// With multiple if ops in a function, each walk of the test pass stops at the
+// first if op that it is able to successfully hoist. Here the first affine.if
+// is already at function level and cannot be hoisted, while the second one is
+// invariant on %i and is hoisted out of the loop.
+
+// CHECK-LABEL: func @multiple_if
+func @multiple_if(%N : index) {
+ affine.if affine_set<() : (0 == 0)>() {
+ call @external() : () -> ()
+ }
+ affine.for %i = 0 to 100 {
+ affine.if affine_set<()[s0] : (s0 >= 0)>()[%N] {
+ call @external() : () -> ()
+ }
+ }
+ return
+}
+// CHECK: affine.if
+// CHECK-NEXT: call
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.if
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: call
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+func @external()
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index fb1346a8fb5b..296e431e5414 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -1,5 +1,6 @@
add_llvm_library(MLIRAffineTransformsTestPasses
TestAffineDataCopy.cpp
+ TestAffineLoopUnswitching.cpp
TestLoopPermutation.cpp
TestParallelismDetection.cpp
TestVectorizationUtils.cpp
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
new file mode 100644
index 000000000000..69ca5ce96da1
--- /dev/null
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
@@ -0,0 +1,60 @@
+//===- TestAffineLoopUnswitching.cpp - Test affine if/else hoisting -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to hoist affine if/else structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#define PASS_NAME "test-affine-loop-unswitch"
+
+using namespace mlir;
+
+namespace {
+
+/// This pass unswitches affine loops: it hoists affine.if/else ops past the
+/// surrounding affine.for/parallel ops they are invariant on, by applying
+/// mlir::hoistAffineIfOp to the affine.if ops of each function.
+struct TestAffineLoopUnswitching
+    : public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
+  TestAffineLoopUnswitching() = default;
+  TestAffineLoopUnswitching(const TestAffineLoopUnswitching &pass) {}
+
+  void runOnFunction() override;
+
+  /// The maximum number of walks over the function; each walk stops after the
+  /// first successful hoisting.
+  constexpr static unsigned kMaxIterations = 5;
+};
+
+} // end anonymous namespace
+
+void TestAffineLoopUnswitching::runOnFunction() {
+  // Each hoisting invalidates a lot of IR around, so every walk stops right
+  // after the first successful if/else hoisting. Walks are repeated until one
+  // completes without hoisting anything, or until kMaxIterations walks have
+  // been run.
+  auto func = getFunction();
+  auto walkFn = [](AffineIfOp op) {
+    return succeeded(hoistAffineIfOp(op)) ? WalkResult::interrupt()
+                                          : WalkResult::advance();
+  };
+  for (unsigned i = 0; i < kMaxIterations; ++i) {
+    // An uninterrupted walk means nothing could be hoisted: fixed point. (The
+    // original code broke out on an *interrupted* walk, i.e. after the very
+    // first successful hoist, contradicting the intent stated above.)
+    if (!func.walk(walkFn).wasInterrupted())
+      break;
+  }
+}
+
+namespace mlir {
+/// Registers TestAffineLoopUnswitching under the PASS_NAME flag
+/// ("test-affine-loop-unswitch").
+void registerTestAffineLoopUnswitchingPass() {
+  PassRegistration<TestAffineLoopUnswitching> registration(
+      PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index fe92a84103aa..fe8ae83a8154 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -40,6 +40,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
+void registerTestAffineLoopUnswitchingPass();
void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
@@ -103,6 +104,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
+ registerTestAffineLoopUnswitchingPass();
registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();
More information about the Mlir-commits
mailing list