[Mlir-commits] [mlir] af5e83f - [MLIR] Introduce utility to hoist affine if/else conditions

Uday Bondhugula llvmlistbot at llvm.org
Wed Apr 15 12:02:58 PDT 2020


Author: Uday Bondhugula
Date: 2020-04-16T00:32:34+05:30
New Revision: af5e83f569819bab68a070ca59128651feefb7ef

URL: https://github.com/llvm/llvm-project/commit/af5e83f569819bab68a070ca59128651feefb7ef
DIFF: https://github.com/llvm/llvm-project/commit/af5e83f569819bab68a070ca59128651feefb7ef.diff

LOG: [MLIR] Introduce utility to hoist affine if/else conditions

This revision introduces a utility to unswitch affine.for/parallel loops
by hoisting affine.if operations past surrounding affine.for/parallel.
The hoisting works for both perfect/imperfect nests and in the presence
of else blocks. The hoisting is currently done to the outermost level
possible. A test pass is used to test the utility.
Add convenience method Operation::getParentWithTrait<Trait>.

Depends on D77487.

Differential Revision: https://reviews.llvm.org/D77870

Added: 
    mlir/include/mlir/Dialect/Affine/Utils.h
    mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/test/Dialect/Affine/loop-unswitch.mlir
    mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/lib/Dialect/Affine/CMakeLists.txt
    mlir/test/lib/Dialect/Affine/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d30e6b26ff2c..cb827395840a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -337,13 +337,16 @@ def AffineIfOp : Affine_Op<"if",
     /// list of AffineIf is not resizable.
     void setConditional(IntegerSet set, ValueRange operands);
 
+    /// Returns true if an else block exists.
+    bool hasElse() { return !elseRegion().empty(); }
+
     Block *getThenBlock() {
       assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
       return &thenRegion().front();
     }
 
     Block *getElseBlock() {
-      assert(!elseRegion().empty() && "Empty 'else' region.");
+      assert(hasElse() && "Empty 'else' region.");
       return &elseRegion().front();
     }
 
@@ -353,7 +356,7 @@ def AffineIfOp : Affine_Op<"if",
       return OpBuilder(&body, std::prev(body.end()));
     }
     OpBuilder getElseBodyBuilder() {
-      assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
+      assert(hasElse() && "No 'else' block");
       Block &body = elseRegion().front();
       return OpBuilder(&body, std::prev(body.end()));
     }
@@ -491,6 +494,9 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
 
     Block *getBody();
     OpBuilder getBodyBuilder();
+    MutableArrayRef<BlockArgument> getIVs() {
+      return getBody()->getArguments();
+    }
     void setSteps(ArrayRef<int64_t> newSteps);
 
     static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }

diff  --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
new file mode 100644
index 000000000000..a2c0211b301e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -0,0 +1,29 @@
+//===- Utils.h - Affine dialect utilities -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a set of utilities for the affine dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_UTILS_H
+#define MLIR_DIALECT_AFFINE_UTILS_H
+
+namespace mlir {
+
+class AffineIfOp;
+struct LogicalResult;
+
+/// Hoists out affine.if/else to as high as possible, i.e., past all invariant
+/// affine.fors/parallel's. `ifOp` is canonicalized first; if that folds or
+/// erases it, failure is returned and `folded` (when non-null) is set to true.
+/// Otherwise `folded` (when non-null) is set to false, and success is returned
+/// iff any hoisting happened, after which the closest isolated-from-above
+/// ancestor of the hoisted op is canonicalized as well. This hoisting could
+/// lead to significant code expansion in some cases.
+LogicalResult hoistAffineIfOp(AffineIfOp ifOp, bool *folded = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_UTILS_H

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index a503c1ec13f0..f2b174c067ff 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -116,6 +116,12 @@ class OpState {
     return getOperation()->getParentOfType<OpTy>();
   }
 
+  /// Convenience forwarder to Operation::getParentWithTrait: yields the nearest
+  /// enclosing (proper ancestor) operation carrying trait `Trait`.
+  template <template <typename T> class Trait>
+  Operation *getParentWithTrait() {
+    Operation *self = getOperation();
+    return self->getParentWithTrait<Trait>();
+  }
+
   /// Return the context this operation belongs to.
   MLIRContext *getContext() { return getOperation()->getContext(); }
 

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 099ca3c610c4..f0be03a5cced 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -126,6 +126,16 @@ class Operation final
     return OpTy();
   }
 
+  /// Walks the chain of enclosing operations (starting from the direct parent,
+  /// never `this`) and returns the first one that has trait `Trait`, or null
+  /// when no ancestor has it.
+  template <template <typename T> class Trait>
+  Operation *getParentWithTrait() {
+    for (Operation *parent = getParentOp(); parent;
+         parent = parent->getParentOp()) {
+      if (parent->hasTrait<Trait>())
+        return parent;
+    }
+    return nullptr;
+  }
+
   /// Return true if this operation is a proper ancestor of the `other`
   /// operation.
   bool isProperAncestor(Operation *other);

diff  --git a/mlir/lib/Dialect/Affine/CMakeLists.txt b/mlir/lib/Dialect/Affine/CMakeLists.txt
index c018b50f967f..95cf0a44f21b 100644
--- a/mlir/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/CMakeLists.txt
@@ -19,3 +19,4 @@ target_link_libraries(MLIRAffine
   )
 
 add_subdirectory(Transforms)
+add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
new file mode 100644
index 000000000000..64738c0cf369
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRAffineUtils
+  Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
+
+  )
+target_link_libraries(MLIRAffineUtils
+  PUBLIC
+  MLIRAffine
+  )

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
new file mode 100644
index 000000000000..811579bb6c8c
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -0,0 +1,175 @@
+//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous transformation utilities for the Affine
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+/// Moves all non-terminator ops of one branch of `ifOp` -- the `else` block
+/// when `elseBlock` is true, the `then` block otherwise -- to just before
+/// `ifOp` in its parent block, then erases `ifOp` along with everything still
+/// inside it (the branch terminator and the other branch).
+static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
+  assert((!elseBlock || ifOp.hasElse()) && "else block expected");
+
+  Block *parent = ifOp.getOperation()->getBlock();
+  Block *source = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
+  auto &sourceOps = source->getOperations();
+  // Stop short of the branch terminator; it must stay behind in `ifOp`.
+  auto terminatorIt = std::prev(sourceOps.end());
+  parent->getOperations().splice(Block::iterator(ifOp), sourceOps,
+                                 sourceOps.begin(), terminatorIt);
+  ifOp.erase();
+}
+
+/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
+/// on. The `ifOp` could be hoisted and placed right before such an operation.
+/// Returns `ifOp` itself when no hoisting is possible. The walk goes past any
+/// affine.if op, and past affine.for/affine.parallel ops none of whose
+/// induction variables are operands of `ifOp`; it stops at any other op and at
+/// the enclosing function. This method assumes that the ifOp has been
+/// canonicalized (to be correct and effective).
+static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
+  auto ifOperands = ifOp.getOperands();
+  // Returns true if `v` is used as an operand of `ifOp`.
+  auto isIfOperand = [&](Value v) { return llvm::is_contained(ifOperands, v); };
+
+  // Walk up the parents past all ops that this conditional is invariant on.
+  Operation *res = ifOp.getOperation();
+  while (Operation *parentOp = res->getParentOp()) {
+    if (isa<FuncOp>(parentOp))
+      break;
+    if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
+      if (isIfOperand(forOp.getInductionVar()))
+        break;
+    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
+      // The condition depends on `parallelOp` if it uses *any* of its IVs. A
+      // `break` inside a loop over the IVs would only leave that inner loop
+      // and wrongly keep hoisting past `parallelOp`.
+      if (llvm::any_of(parallelOp.getIVs(), isIfOperand))
+        break;
+    } else if (!isa<AffineIfOp>(parentOp)) {
+      // Won't walk up past anything other than affine.for/parallel/if ops.
+      break;
+    }
+    // Invariant on `parentOp` (affine.if ops can always be hoisted past).
+    res = parentOp;
+  }
+  return res;
+}
+
+/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
+/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
+/// otherwise the same `ifOp`.
+///
+/// On success the result is a new affine.if (same integer set and operands as
+/// `ifOp`, always created with an else block) placed where `hoistOverOp` was:
+///   - its 'then' block holds `hoistOverOp` with `ifOp` replaced by the body of
+///     `ifOp`'s 'then' block;
+///   - its 'else' block holds a clone of `hoistOverOp` in which the clone of
+///     `ifOp` is replaced by its 'else' body, or simply removed when `ifOp`
+///     has no else block.
+/// The original `ifOp` is erased in the process, so callers must not use it
+/// afterwards.
+static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
+  // No hoisting to do.
+  if (hoistOverOp == ifOp)
+    return ifOp;
+
+  // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
+  // the else block. Then drop the else block of the original 'if' in the 'then'
+  // branch while promoting its then block, and analogously drop the 'then'
+  // block of the original 'if' from the 'else' branch while promoting its else
+  // block.
+  BlockAndValueMapping operandMap;
+  // The builder inserts before `hoistOverOp`, so the new 'if' lands right
+  // in front of it.
+  OpBuilder b(hoistOverOp);
+  auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
+                                          ifOp.getOperands(),
+                                          /*elseBlock=*/true);
+
+  // Create a clone of hoistOverOp to use for the else branch of the hoisted
+  // conditional. The else block may get optimized away if empty.
+  Operation *hoistOverOpClone = nullptr;
+  // We use this unique name to identify/find `ifOp`'s clone in the else
+  // version.
+  Identifier idForIfOp = b.getIdentifier("__mlir_if_hoisting");
+  operandMap.clear();
+  b.setInsertionPointAfter(hoistOverOp);
+  // We'll set an attribute to identify this op in a clone of this sub-tree.
+  // It must be set BEFORE cloning so that the clone carries it too. The marker
+  // does not survive: `ifOp` is erased by the promotion below, and its clone
+  // is either erased or promoted (which also erases it).
+  ifOp.setAttr(idForIfOp, b.getBoolAttr(true));
+  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
+
+  // Promote the 'then' block of the original affine.if in the then version.
+  // This erases `ifOp`; only the clone remains referenced from here on.
+  promoteIfBlock(ifOp, /*elseBlock=*/false);
+
+  // Move the then version to the hoisted if op's 'then' block.
+  auto *thenBlock = hoistedIfOp.getThenBlock();
+  thenBlock->getOperations().splice(thenBlock->begin(),
+                                    hoistOverOp->getBlock()->getOperations(),
+                                    Block::iterator(hoistOverOp));
+
+  // Find the clone of the original affine.if op in the else version.
+  AffineIfOp ifCloneInElse;
+  hoistOverOpClone->walk([&](AffineIfOp ifClone) {
+    if (!ifClone.getAttr(idForIfOp))
+      return WalkResult::advance();
+    ifCloneInElse = ifClone;
+    return WalkResult::interrupt();
+  });
+  assert(ifCloneInElse && "if op clone should exist");
+  // For the else block, promote the else block of the original 'if' if it had
+  // one; otherwise, the op itself is to be erased.
+  if (!ifCloneInElse.hasElse())
+    ifCloneInElse.erase();
+  else
+    promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
+
+  // Move the else version into the else block of the hoisted if op.
+  auto *elseBlock = hoistedIfOp.getElseBlock();
+  elseBlock->getOperations().splice(
+      elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
+      Block::iterator(hoistOverOpClone));
+
+  return hoistedIfOp;
+}
+
+// Returns success if any hoisting happened. If canonicalization erases `ifOp`,
+// failure is returned and `*folded` (when `folded` is non-null) is set to true;
+// otherwise `*folded` is set to false.
+LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
+  // Apply canonicalization patterns and folding - this is necessary for the
+  // hoisting check to be correct (operands should be composed), and to be more
+  // effective (no unused operands). Since the pattern rewriter's folding is
+  // entangled with application of patterns, we may fold/end up erasing the op,
+  // in which case we return with `folded` being set.
+  OwningRewritePatternList patterns;
+  AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
+  // Initialized so that `erased` is never read uninitialized.
+  bool erased = false;
+  applyOpPatternsAndFold(ifOp, patterns, &erased);
+  if (erased) {
+    if (folded)
+      *folded = true;
+    return failure();
+  }
+  if (folded)
+    *folded = false;
+
+  // The folding above should have ensured this, but the affine.if's
+  // canonicalization is missing composition of affine.applys into it. Valid
+  // composed operands are top-level values and the induction variables of
+  // surrounding affine.for and affine.parallel ops.
+  assert(llvm::all_of(ifOp.getOperands(),
+                      [](Value v) {
+                        if (isTopLevelValue(v) || isForInductionVar(v))
+                          return true;
+                        auto arg = v.dyn_cast<BlockArgument>();
+                        if (!arg)
+                          return false;
+                        Operation *owner = arg.getOwner()->getParentOp();
+                        return owner && isa<AffineParallelOp>(owner);
+                      }) &&
+         "operands not composed");
+
+  // We are going to hoist as high as possible.
+  // TODO: this could be customized in the future.
+  auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
+
+  AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
+  // Nothing to hoist over.
+  if (hoistedIfOp == ifOp)
+    return failure();
+
+  // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
+  // a sequence of affine.fors that are all perfectly nested).
+  applyPatternsAndFoldGreedily(
+      hoistedIfOp.getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
+      std::move(patterns));
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/Affine/loop-unswitch.mlir b/mlir/test/Dialect/Affine/loop-unswitch.mlir
new file mode 100644
index 000000000000..801eb059511c
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-unswitch.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt %s -split-input-file -test-affine-loop-unswitch | FileCheck %s
+
+// CHECK-DAG: #[[SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)>
+
+// CHECK-LABEL: func @if_else_imperfect
+func @if_else_imperfect(%A : memref<100xi32>, %B : memref<100xi32>, %v : i32) {
+// CHECK: %[[A:.*]]: memref<100xi32>, %[[B:.*]]: memref
+  affine.for %i = 0 to 100 {
+    affine.load %A[%i] : memref<100xi32>
+    affine.for %j = 0 to 100 {
+      affine.load %A[%j] : memref<100xi32>
+      affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
+        affine.load %B[%j] : memref<100xi32>
+      }
+      call @external() : () -> ()
+    }
+    affine.load %A[%i] : memref<100xi32>
+  }
+  return
+}
+func @external()
+
+// CHECK:       affine.for %[[I:.*]] = 0 to 100 {
+// CHECK-NEXT:    affine.load %[[A]][%[[I]]]
+// CHECK-NEXT:    affine.if #[[SET]](%[[I]]) {
+// CHECK-NEXT:      affine.for %[[J:.*]] = 0 to 100 {
+// CHECK-NEXT:        affine.load %[[A]][%[[J]]]
+// CHECK-NEXT:        affine.load %[[B]][%[[J]]]
+// CHECK-NEXT:        call
+// CHECK-NEXT:      }
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      affine.for %[[JJ:.*]] = 0 to 100 {
+// CHECK-NEXT:        affine.load %[[A]][%[[JJ]]]
+// CHECK-NEXT:        call
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:    affine.load %[[A]][%[[I]]]
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+
+// -----
+
+func @foo()
+func @bar()
+func @abc()
+func @xyz()
+
+// CHECK-LABEL: func @if_then_perfect
+func @if_then_perfect(%A : memref<100xi32>, %v : i32) {
+  affine.for %i = 0 to 100 {
+    affine.for %j = 0 to 100 {
+      affine.for %k = 0 to 100 {
+        affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
+          affine.load %A[%i] : memref<100xi32>
+        }
+      }
+    }
+  }
+  return
+}
+// CHECK:      affine.for
+// CHECK-NEXT:   affine.if
+// CHECK-NEXT:     affine.for
+// CHECK-NEXT:       affine.for
+// CHECK-NOT:    else
+
+
+// CHECK-LABEL: func @if_else_perfect
+func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
+  affine.for %i = 0 to 99 {
+    affine.for %j = 0 to 100 {
+      affine.for %k = 0 to 100 {
+        call @foo() : () -> ()
+        affine.if affine_set<(d0, d1) : (d0 - 2 >= 0, -d1 + 80 >= 0)>(%i, %j) {
+          affine.load %A[%i] : memref<100xi32>
+          call @abc() : () -> ()
+        } else {
+          affine.load %A[%i + 1] : memref<100xi32>
+          call @xyz() : () -> ()
+        }
+        call @bar() : () -> ()
+      }
+    }
+  }
+  return
+}
+// CHECK:      affine.for
+// CHECK-NEXT:   affine.for
+// CHECK-NEXT:     affine.if
+// CHECK-NEXT:       affine.for
+// CHECK-NEXT:         call @foo
+// CHECK-NEXT:         affine.load %{{.*}}[%{{.*}}]
+// CHECK-NEXT:         call @abc
+// CHECK-NEXT:         call @bar
+// CHECK-NEXT:       }
+// CHECK-NEXT:     else
+// CHECK-NEXT:       affine.for
+// CHECK-NEXT:         call @foo
+// CHECK-NEXT:         affine.load %{{.*}}[%{{.*}} + 1]
+// CHECK-NEXT:         call @xyz
+// CHECK-NEXT:         call @bar
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @if_then_imperfect
+func @if_then_imperfect(%A : memref<100xi32>, %N : index) {
+  affine.for %i = 0 to 100 {
+    affine.load %A[0] : memref<100xi32>
+    affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%N) {
+      affine.load %A[%i] : memref<100xi32>
+    }
+  }
+  return
+}
+// CHECK:       affine.if
+// CHECK-NEXT:    affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:    }
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+
+// Check if unused operands are dropped: hence, hoisting is possible.
+// CHECK-LABEL: func @hoist_after_canonicalize
+func @hoist_after_canonicalize() {
+  affine.for %i = 0 to 100 {
+    affine.for %j = 0 to 100 {
+      affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%j)  {
+        affine.if affine_set<(d0, d1) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%i, %j)  {
+          // The call to @foo is to avoid DCE on affine.if.
+          call @foo() : () -> ()
+        }
+      }
+    }
+  }
+  return
+}
+// CHECK:      affine.for
+// CHECK-NEXT:   affine.if
+// CHECK-NEXT:     affine.for
+// CHECK-NEXT:       affine.if
+// CHECK-NEXT:         call
+// CHECK-NEXT:       }
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// CHECK-LABEL: func @handle_dead_if
+func @handle_dead_if(%N : index) {
+  affine.for %i = 0 to 100 {
+      affine.if affine_set<(d0) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%N)  {
+      }
+  }
+  return
+}
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+// A test case with affine.parallel.
+
+#flb1 = affine_map<(d0) -> (d0 * 3)>
+#fub1 = affine_map<(d0) -> (d0 * 3 + 3)>
+#flb0 = affine_map<(d0) -> (d0 * 16)>
+#fub0 = affine_map<(d0) -> (d0 * 16 + 16)>
+#pub1 = affine_map<(d0)[s0] -> (s0, d0 * 3 + 3)>
+#pub0 = affine_map<(d0)[s0] -> (s0, d0 * 16 + 16)>
+#lb1 = affine_map<(d0) -> (d0 * 480)>
+#ub1 = affine_map<(d0)[s0] -> (s0, d0 * 480 + 480)>
+#lb0 = affine_map<(d0) -> (d0 * 110)>
+#ub0 = affine_map<(d0)[s0] -> (d0 * 110 + 110, s0 floordiv 3)>
+
+#set0 = affine_set<(d0, d1)[s0, s1] : (d0 * -16 + s0 - 16 >= 0, d1 * -3 + s1 - 3 >= 0)>
+
+// CHECK-LABEL: func @perfect_if_else
+func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 : index,
+            %arg5 : index, %arg6 : index, %sym : index) {
+  affine.for %arg7 = #lb0(%arg5) to min #ub0(%arg5)[%sym] {
+    affine.parallel (%i0, %j0) = (0, 0) to (symbol(%sym), 100) step (10, 10) {
+      affine.for %arg8 = #lb1(%arg4) to min #ub1(%arg4)[%sym] {
+        affine.if #set0(%arg6, %arg7)[%sym, %sym] {
+          affine.for %arg9 = #flb0(%arg6) to #fub0(%arg6) {
+            affine.for %arg10 = #flb1(%arg7) to #fub1(%arg7) {
+              affine.load %arg0[0, 0] : memref<?x?xf64>
+            }
+          }
+        } else {
+          affine.for %arg9 = #lb0(%arg6) to min #pub0(%arg6)[%sym] {
+            affine.for %arg10 = #lb1(%arg7) to min #pub1(%arg7)[%sym] {
+              affine.load %arg0[0, 0] : memref<?x?xf64>
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:       affine.for
+// CHECK-NEXT:    affine.if
+// CHECK-NEXT:      affine.parallel
+// CHECK-NEXT:        affine.for
+// CHECK-NEXT:          affine.for
+// CHECK-NEXT:            affine.for
+// CHECK-NEXT:              affine.load
+// CHECK-NEXT:            }
+// CHECK-NEXT:          }
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      affine.parallel
+// CHECK-NEXT:        affine.for
+// CHECK-NEXT:          affine.for
+// CHECK-NEXT:            affine.for
+// CHECK-NEXT:              affine.load
+// CHECK-NEXT:            }
+// CHECK-NEXT:          }
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+
+// With multiple if ops in a function, the test pass just looks for the first if
+// op that it is able to successfully hoist.
+
+// CHECK-LABEL: func @multiple_if
+func @multiple_if(%N : index) {
+  affine.if affine_set<() : (0 == 0)>() {
+    call @external() : () -> ()
+  }
+  affine.for %i = 0 to 100 {
+    affine.if affine_set<()[s0] : (s0 >= 0)>()[%N] {
+      call @external() : () -> ()
+    }
+  }
+  return
+}
+// CHECK:      affine.if
+// CHECK-NEXT:   call
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.if
+// CHECK-NEXT:   affine.for
+// CHECK-NEXT:     call
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+func @external()

diff  --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index fb1346a8fb5b..296e431e5414 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRAffineTransformsTestPasses
   TestAffineDataCopy.cpp
+  TestAffineLoopUnswitching.cpp
   TestLoopPermutation.cpp
   TestParallelismDetection.cpp
   TestVectorizationUtils.cpp

diff  --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
new file mode 100644
index 000000000000..69ca5ce96da1
--- /dev/null
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
@@ -0,0 +1,60 @@
+//===- TestAffineLoopUnswitching.cpp - Test affine if/else hoisting -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to hoist affine if/else structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#define PASS_NAME "test-affine-loop-unswitch"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass exercising mlir::hoistAffineIfOp (affine loop unswitching /
+/// affine.if hoisting) on the affine.if ops of a function.
+struct TestAffineLoopUnswitching
+    : public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
+  /// Upper bound on the number of walk-and-hoist rounds per function.
+  static constexpr unsigned kMaxIterations = 5;
+
+  TestAffineLoopUnswitching() = default;
+  TestAffineLoopUnswitching(const TestAffineLoopUnswitching &) {}
+
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+void TestAffineLoopUnswitching::runOnFunction() {
+  // Each hoisting invalidates a lot of IR around. Just stop the walk after the
+  // first if/else hoisting, and repeat until no more hoisting can be done, or
+  // the maximum number of iterations have been run.
+  auto func = getFunction();
+  // Interrupts the walk as soon as one affine.if has been hoisted.
+  auto walkFn = [](AffineIfOp op) {
+    return succeeded(hoistAffineIfOp(op)) ? WalkResult::interrupt()
+                                          : WalkResult::advance();
+  };
+  for (unsigned i = 0; i < kMaxIterations; ++i) {
+    // A walk that ran to completion hoisted nothing, so a fixed point has been
+    // reached. An interrupted walk means something was hoisted and the IR
+    // changed, so walk again.
+    if (!func.walk(walkFn).wasInterrupted())
+      break;
+  }
+}
+
+namespace mlir {
+/// Makes the affine loop unswitching test pass available under `PASS_NAME`
+/// ("test-affine-loop-unswitch") through a PassRegistration.
+void registerTestAffineLoopUnswitchingPass() {
+  PassRegistration<TestAffineLoopUnswitching> registration(
+      PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index fe92a84103aa..fe8ae83a8154 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -40,6 +40,7 @@ void registerSimpleParametricTilingPass();
 void registerSymbolTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAllReduceLoweringPass();
+void registerTestAffineLoopUnswitchingPass();
 void registerTestLinalgMatmulToVectorPass();
 void registerTestLoopPermutationPass();
 void registerTestCallGraphPass();
@@ -103,6 +104,7 @@ void registerTestPasses() {
   registerSymbolTestPasses();
   registerTestAffineDataCopyPass();
   registerTestAllReduceLoweringPass();
+  registerTestAffineLoopUnswitchingPass();
   registerTestLinalgMatmulToVectorPass();
   registerTestLoopPermutationPass();
   registerTestCallGraphPass();


        


More information about the Mlir-commits mailing list