[Mlir-commits] [mlir] [MLIR] Support interrupting AffineExpr walks (PR #74792)

Uday Bondhugula llvmlistbot at llvm.org
Fri Dec 8 17:46:13 PST 2023


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/74792

>From cba37319e8ef945d12309f05737d856c8b1da00c Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Tue, 5 Dec 2023 22:25:48 +0530
Subject: [PATCH] [MLIR] Support interrupting AffineExpr walks

Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows a walk to stop early
once a condition is met. Also, switch from std::function to
llvm::function_ref for the walk callback, and use an interruptible walk
in isNormalizedMemRefDynamicDim.
---
 mlir/include/mlir/IR/AffineExpr.h        | 19 +++++++-
 mlir/include/mlir/IR/AffineExprVisitor.h | 56 ++++++++++++++++++----
 mlir/lib/Dialect/Affine/Utils/Utils.cpp  | 60 ++++++++++++------------
 mlir/lib/IR/AffineExpr.cpp               | 41 +++++++++++-----
 4 files changed, 122 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 40e9d28ce5d3a..63314cc756355 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_IR_AFFINEEXPR_H
 #define MLIR_IR_AFFINEEXPR_H
 
+#include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
@@ -123,8 +124,13 @@ class AffineExpr {
   /// Return true if the affine expression involves AffineSymbolExpr `position`.
   bool isFunctionOfSymbol(unsigned position) const;
 
-  /// Walk all of the AffineExpr's in this expression in postorder.
-  void walk(std::function<void(AffineExpr)> callback) const;
+  /// Walk all of the AffineExprs in this expression in postorder. `callback`
+  /// may return either `void` or a WalkResult; if it returns
+  /// WalkResult::interrupt(), the walk stops and returns that result.
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  RetT walk(FnT &&callback) const {
+    return walk<RetT>(*this, callback);
+  }
 
   /// This method substitutes any uses of dimensions and symbols (e.g.
   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +208,15 @@ class AffineExpr {
 
 protected:
   ImplType *expr{nullptr};
+
+private:
+  /// A trampoline for the templated non-static AffineExpr::walk method to
+  /// dispatch `callback`s returning either `void` or `WalkResult`. Walks all
+  /// of the AffineExprs in `e` in postorder. Users should use the regular
+  /// (non-static) `walk` method.
+  template <typename WalkRetTy>
+  static WalkRetTy walk(AffineExpr e,
+                        function_ref<WalkRetTy(AffineExpr)> callback);
 };
 
 /// Affine binary operation expression. An affine binary operation could be an
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 2860e73c8f428..3e1bbb4b3fa0e 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
 /// functions in your class. This class is defined in terms of statically
 /// resolved overloading, not virtual functions.
 ///
+/// The visitor is templated on its return type (`RetTy`). With a WalkResult
+/// return type, `walkPostOrder` stops as soon as a visit method interrupts.
+///
 /// For example, here is a visitor that counts the number of for AffineDimExprs
 /// in an AffineExpr.
 ///
@@ -65,7 +68,6 @@ namespace mlir {
 /// virtual function call overhead. Defining and using a AffineExprVisitor is
 /// just as efficient as having your own switch instruction over the instruction
 /// opcode.
-
 template <typename SubClass, typename RetTy>
 class AffineExprVisitorBase {
 public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
   RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 };
 
+/// See documentation for AffineExprVisitorBase. This visitor supports
+/// interrupting walks when a `WalkResult` is used for `RetTy`.
 template <typename SubClass, typename RetTy = void>
 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
   //===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
 private:
   // Walk the operands - each operand is itself walked in post order.
   RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
-    walkPostOrder(expr.getLHS());
-    walkPostOrder(expr.getRHS());
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getLHS()).wasInterrupted())
+        return WalkResult::interrupt();
+    } else {
+      walkPostOrder(expr.getLHS());
+    }
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getRHS()).wasInterrupted())
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    } else {
+      return walkPostOrder(expr.getRHS());
+    }
   }
 };
 
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e..578d03c629285 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
 static bool
 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
-                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
-                             MLIRContext *context) {
-  bool isDynamicDim = false;
+                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
   AffineExpr expr = layoutMap.getResults()[dim];
   // Check if affine expr of the dimension includes dynamic dimension of input
   // memrefType.
-  expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
-    if (isa<AffineDimExpr>(e)) {
-      for (unsigned dm : inMemrefTypeDynDims) {
-        if (e == getAffineDimExpr(dm, context)) {
-          isDynamicDim = true;
-        }
-      }
-    }
-  });
-  return isDynamicDim;
+  MLIRContext *context = layoutMap.getContext();
+  return expr
+      .walk([&](AffineExpr e) {
+        if (isa<AffineDimExpr>(e) &&
+            llvm::any_of(inMemrefTypeDynDims, [&](unsigned dynDim) {
+              return e == getAffineDimExpr(dynDim, context);
+            }))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
 }
 
 /// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,27 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   MLIRContext *context = memrefType.getContext();
   for (unsigned d = 0; d < newRank; ++d) {
     // Check if this dimension is dynamic.
-    bool isDynDim =
-        isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
-    if (isDynDim) {
+    if (isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
       newShape[d] = ShapedType::kDynamic;
-    } else {
-      // The lower bound for the shape is always zero.
-      std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
-      // For a static memref and an affine map with no symbols, this is
-      // always bounded. However, when we have symbols, we may not be able to
-      // obtain a constant upper bound. Also, mapping to a negative space is
-      // invalid for normalization.
-      if (!ubConst.has_value() || *ubConst < 0) {
-        LLVM_DEBUG(llvm::dbgs()
-                   << "can't normalize map due to unknown/invalid upper bound");
-        return memrefType;
-      }
-      // If dimension of new memrefType is dynamic, the value is -1.
-      newShape[d] = *ubConst + 1;
+      continue;
+    }
+    // The lower bound for the shape is always zero.
+    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded. However, when we have symbols, we may not be able to
+    // obtain a constant upper bound. Also, mapping to a negative space is
+    // invalid for normalization.
+    if (!ubConst.has_value() || *ubConst < 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "can't normalize map due to unknown/invalid upper bound");
+      return memrefType;
     }
+    // Static dimension: its size is the constant upper bound plus one.
+    newShape[d] = *ubConst + 1;
   }
 
   // Create the new memref type after trivializing the old layout map.
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setLayout(AffineMapAttr::get(
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 038ceea286a36..a90b264a8edd2 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
 
 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
 
-/// Walk all of the AffineExprs in this subgraph in postorder.
-void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
-  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
-    std::function<void(AffineExpr)> callback;
-
-    AffineExprWalker(std::function<void(AffineExpr)> callback)
-        : callback(std::move(callback)) {}
-
-    void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
-    void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
-    void visitDimExpr(AffineDimExpr expr) { callback(expr); }
-    void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+/// Walk all of the AffineExprs in `e` in postorder. This is a private helper
+/// that the templated (non-static) `walk` method dispatches to for callbacks
+/// returning either `void` or `WalkResult`; users should call that instead.
+template <typename WalkRetTy>
+WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
+                                 function_ref<WalkRetTy(AffineExpr)> callback) {
+  struct AffineExprWalker
+      : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
+    function_ref<WalkRetTy(AffineExpr)> callback;
+
+    explicit AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
+        : callback(callback) {}
+
+    WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
+    WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
   };
 
-  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
+  return AffineExprWalker(callback).walkPostOrder(e);
 }
+// Explicitly instantiate for the two supported return types.
+template void mlir::AffineExpr::walk(AffineExpr e,
+                                     function_ref<void(AffineExpr)> callback);
+template WalkResult
+mlir::AffineExpr::walk(AffineExpr e,
+                       function_ref<WalkResult(AffineExpr)> callback);
 
 // Dispatch affine expression construction based on kind.
 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,



More information about the Mlir-commits mailing list