[Mlir-commits] [mlir] 912fedf - [mlir][affine][NFC] Split `reifyValueBound` in two functions
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 18 00:41:06 PDT 2023
Author: Matthias Springer
Date: 2023-04-18T16:40:56+09:00
New Revision: 912fedfbe5b9a2120c2a48a142abaff3bdc7a971
URL: https://github.com/llvm/llvm-project/commit/912fedfbe5b9a2120c2a48a142abaff3bdc7a971
DIFF: https://github.com/llvm/llvm-project/commit/912fedfbe5b9a2120c2a48a142abaff3bdc7a971.diff
LOG: [mlir][affine][NFC] Split `reifyValueBound` in two functions
There are now two entry points. One for shaped values and one for index-typed values. This addresses a comment in D146524.
Differential Revision: https://reviews.llvm.org/D147987
Added:
Modified:
mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index db73729346f7b..70928813bab87 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -49,20 +49,9 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
-/// Reify a bound for the given index-typed value or shape dimension size in
-/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
-/// `value` is index-typed.
-///
-/// By default, lower/equal bounds are closed and upper bounds are open. If
-/// `closedUB` is set to "true", upper bounds are also closed.
-FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- bool closedUB = false);
-
-/// Reify a bound for the given index-typed value or shape dimension size in
-/// terms of SSA values for which `stopCondition` is met. `dim` must be
-/// `nullopt` if and only if `value` is index-typed.
+/// Reify a bound for the given index-typed value in terms of SSA values for
+/// which `stopCondition` is met. If no stop condition is specified, reify in
+/// terms of the operands of the owner op.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
@@ -77,11 +66,22 @@ FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
/// is an EQ bound for %1.
/// * Otherwise, if the owners of %a, %b or %c do not implement the
/// ValueBoundsOpInterface, no bound can be computed.
-FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB = false);
+FailureOr<OpFoldResult> reifyIndexValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
+ bool closedUB = false);
+
+/// Reify a bound for the specified dimension of the given shaped value in terms
+/// of SSA values for which `stopCondition` is met. If no stop condition is
+/// specified, reify in terms of the operands of the owner op.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult> reifyShapedValueDimBound(
+ OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+ int64_t dim,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
+ bool closedUB = false);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 2a32874c1891e..3fc0fedd89bbf 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -15,25 +15,11 @@
using namespace mlir;
-FailureOr<OpFoldResult>
-mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, bool closedUB) {
- // We are trying to reify a bound for `value`. Construct a stop condition that
- // evaluates to "true" for any SSA value expect for `value`. I.e., the bound
- // will be computed in terms of any SSA values except for `value`. The first
- // such values are operands of the owner of `value`.
- auto stopCondition = [&](Value v, std::optional<int64_t> d) {
- // Reify in terms of SSA values that are different from `value`.
- return v != value;
- };
- return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB);
-}
-
-FailureOr<OpFoldResult> mlir::reifyValueBound(
- OpBuilder &b, Location loc, presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
- bool closedUB) {
+static FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ Value value, std::optional<int64_t> dim,
+ function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+ bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
@@ -85,3 +71,31 @@ FailureOr<OpFoldResult> mlir::reifyValueBound(
return static_cast<OpFoldResult>(
b.create<AffineApplyOp>(loc, boundMap, operands).getResult());
}
+
+FailureOr<OpFoldResult> mlir::reifyShapedValueDimBound(
+ OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+ int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
+ bool closedUB) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ // We are trying to reify a bound for `value` in terms of the owning op's
+ // operands. Construct a stop condition that evaluates to "true" for any SSA
+ // value except for `value`. I.e., the bound will be computed in terms of
+ // any SSA values except for `value`. The first such values are operands of
+ // the owner of `value`.
+ return v != value;
+ };
+ return reifyValueBound(b, loc, type, value, dim,
+ stopCondition ? stopCondition : reifyToOperands,
+ closedUB);
+}
+
+FailureOr<OpFoldResult> mlir::reifyIndexValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ return v != value;
+ };
+ return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ stopCondition ? stopCondition : reifyToOperands,
+ closedUB);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 7a6c58aee5938..f61f23eb502a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -462,9 +462,9 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
// of the enclosing loops.
for (auto forOp : packingLoops) {
// Compute an upper bound `ubVal` for the upper bound of `forOp`.
- FailureOr<OpFoldResult> loopUb = reifyValueBound(
+ FailureOr<OpFoldResult> loopUb = reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
- /*dim=*/std::nullopt, /*stopCondition=*/
+ /*stopCondition=*/
[&](Value v, std::optional<int64_t> d) {
if (v == forOp.getUpperBound())
return false;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 922c6571e1bdf..d8ff841d3680d 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -130,8 +130,13 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
reified =
FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else {
- reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
- dim, stopCondition);
+ if (dim) {
+ reified = reifyShapedValueDimBound(rewriter, op->getLoc(), *boundType,
+ value, *dim, stopCondition);
+ } else {
+ reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
+ value, stopCondition);
+ }
}
if (failed(reified)) {
op->emitOpError("could not reify bound");
More information about the Mlir-commits mailing list