[Mlir-commits] [mlir] [MLIR] Support interrupting AffineExpr walks (PR #74792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 7 17:02:41 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Uday Bondhugula (bondhugula)
<details>
<summary>Changes</summary>
Support returning a WalkResult from AffineExpr walk callbacks and support
interrupting walks, along the lines of Operation::walk. This allows a walk to
stop early once a condition is met. Also, switch from std::function to
llvm::function_ref for the walk callback.
---
Full diff: https://github.com/llvm/llvm-project/pull/74792.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/AffineExpr.h (+23-2)
- (modified) mlir/include/mlir/IR/AffineExprVisitor.h (+48-8)
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+29-31)
- (modified) mlir/lib/IR/AffineExpr.cpp (+28-13)
``````````diff
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 40e9d28ce5d3a..181a24472473a 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
#ifndef MLIR_IR_AFFINEEXPR_H
#define MLIR_IR_AFFINEEXPR_H
+#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
@@ -123,8 +124,19 @@ class AffineExpr {
/// Return true if the affine expression involves AffineSymbolExpr `position`.
bool isFunctionOfSymbol(unsigned position) const;
- /// Walk all of the AffineExpr's in this expression in postorder.
- void walk(std::function<void(AffineExpr)> callback) const;
+ /// Walk all of the AffineExprs in this expression in postorder. The walk
+ /// callback may return either `void` or a WalkResult; with a WalkResult,
+ /// returning WalkResult::interrupt() stops the walk and is returned here.
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ std::enable_if_t<std::is_same<RetT, void>::value, RetT>
+ walk(FnT &&callback) const {
+ return walk<void>(*this, callback);
+ }
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
+ walk(FnT &&callback) const {
+ return walk<WalkResult>(*this, callback);
+ }
/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +214,15 @@ class AffineExpr {
protected:
ImplType *expr{nullptr};
+
+private:
+ /// Trampoline for the templated non-static AffineExpr::walk methods: it
+ /// dispatches a `callback` returning either `void` or WalkResult. Walks all
+ /// of the AffineExprs in `e` in postorder. Users should use the regular
+ /// (non-static) `walk` method.
+ template <typename WalkRetTy>
+ static WalkRetTy walk(AffineExpr e,
+ function_ref<WalkRetTy(AffineExpr)> callback);
};
/// Affine binary operation expression. An affine binary operation could be an
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 2860e73c8f428..5b3663d1dea7e 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
/// functions in your class. This class is defined in terms of statically
/// resolved overloading, not virtual functions.
///
+/// The visitor is templated on its return type (`RetTy`). With a WalkResult
+/// return type, a visit method returning WalkResult::interrupt() stops the walk.
+///
/// For example, here is a visitor that counts the number of for AffineDimExprs
/// in an AffineExpr.
///
@@ -65,7 +68,6 @@ namespace mlir {
/// virtual function call overhead. Defining and using a AffineExprVisitor is
/// just as efficient as having your own switch instruction over the instruction
/// opcode.
-
template <typename SubClass, typename RetTy>
class AffineExprVisitorBase {
public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
};
+/// See documentation for AffineExprVisitorBase. This visitor supports
+/// interrupting walks when a `WalkResult` is used for `RetTy`.
template <typename SubClass, typename RetTy = void>
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
private:
// Walk the operands - each operand is itself walked in post order.
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
- walkPostOrder(expr.getLHS());
- walkPostOrder(expr.getRHS());
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkPostOrder(expr.getLHS()).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkPostOrder(expr.getLHS());
+ }
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkPostOrder(expr.getRHS()).wasInterrupted())
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ } else {
+ walkPostOrder(expr.getRHS());
+ }
}
};
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e..578d03c629285 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
- SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
- MLIRContext *context) {
- bool isDynamicDim = false;
+ SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
AffineExpr expr = layoutMap.getResults()[dim];
// Check if affine expr of the dimension includes dynamic dimension of input
// memrefType.
- expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
- if (isa<AffineDimExpr>(e)) {
- for (unsigned dm : inMemrefTypeDynDims) {
- if (e == getAffineDimExpr(dm, context)) {
- isDynamicDim = true;
- }
- }
- }
- });
- return isDynamicDim;
+ MLIRContext *context = layoutMap.getContext();
+ return expr
+ .walk([&](AffineExpr e) {
+ if (isa<AffineDimExpr>(e) &&
+ llvm::any_of(inMemrefTypeDynDims, [&](unsigned dynDim) {
+ return e == getAffineDimExpr(dynDim, context);
+ }))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .wasInterrupted();
}
/// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
MLIRContext *context = memrefType.getContext();
for (unsigned d = 0; d < newRank; ++d) {
// Check if this dimension is dynamic.
- bool isDynDim =
- isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
- if (isDynDim) {
+ if (bool isDynDim =
+ isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
newShape[d] = ShapedType::kDynamic;
- } else {
- // The lower bound for the shape is always zero.
- std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
- // For a static memref and an affine map with no symbols, this is
- // always bounded. However, when we have symbols, we may not be able to
- // obtain a constant upper bound. Also, mapping to a negative space is
- // invalid for normalization.
- if (!ubConst.has_value() || *ubConst < 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "can't normalize map due to unknown/invalid upper bound");
- return memrefType;
- }
- // If dimension of new memrefType is dynamic, the value is -1.
- newShape[d] = *ubConst + 1;
+ continue;
+ }
+ // The lower bound for the shape is always zero.
+ std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
+ // For a static memref and an affine map with no symbols, this is
+ // always bounded. However, when we have symbols, we may not be able to
+ // obtain a constant upper bound. Also, mapping to a negative space is
+ // invalid for normalization.
+ if (!ubConst.has_value() || *ubConst < 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "can't normalize map due to unknown/invalid upper bound");
+ return memrefType;
}
+ // Static dimension: its size is the constant upper bound plus one.
+ newShape[d] = *ubConst + 1;
}
// Create the new memref type after trivializing the old layout map.
MemRefType newMemRefType =
MemRefType::Builder(memrefType)
.setShape(newShape)
.setLayout(AffineMapAttr::get(
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 038ceea286a36..a90b264a8edd2 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
-/// Walk all of the AffineExprs in this subgraph in postorder.
-void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
- struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
- std::function<void(AffineExpr)> callback;
-
- AffineExprWalker(std::function<void(AffineExpr)> callback)
- : callback(std::move(callback)) {}
-
- void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
- void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
- void visitDimExpr(AffineDimExpr expr) { callback(expr); }
- void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+/// Walk all of the AffineExprs in `e` in postorder. This is a private static
+/// trampoline that dispatches walk callbacks returning `void` or WalkResult.
+/// Users should use the regular (non-static) `walk` method.
+template <typename WalkRetTy>
+WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<WalkRetTy(AffineExpr)> callback) {
+ struct AffineExprWalker
+ : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
+ function_ref<WalkRetTy(AffineExpr)> callback;
+
+ AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
+ : callback(callback) {}
+
+ WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+ return callback(expr);
+ }
+ WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
+ return callback(expr);
+ }
+ WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
+ WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
};
- AffineExprWalker(std::move(callback)).walkPostOrder(*this);
+ return AffineExprWalker(callback).walkPostOrder(e);
}
+// The template is defined only in this file: instantiate void and WalkResult.
+template void mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<void(AffineExpr)> callback);
+template WalkResult
+mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<WalkResult(AffineExpr)> callback);
// Dispatch affine expression construction based on kind.
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
``````````
</details>
https://github.com/llvm/llvm-project/pull/74792
More information about the Mlir-commits
mailing list