[Mlir-commits] [mlir] 0dc9087 - [mlir][Interfaces] ValueBoundsOpInterface: Compute constant bounds
Matthias Springer
llvmlistbot at llvm.org
Thu Apr 6 19:39:22 PDT 2023
Author: Matthias Springer
Date: 2023-04-07T11:35:02+09:00
New Revision: 0dc9087ac752659d29fdea6b9fabdd1b7987c996
URL: https://github.com/llvm/llvm-project/commit/0dc9087ac752659d29fdea6b9fabdd1b7987c996
DIFF: https://github.com/llvm/llvm-project/commit/0dc9087ac752659d29fdea6b9fabdd1b7987c996.diff
LOG: [mlir][Interfaces] ValueBoundsOpInterface: Compute constant bounds
Add a helper function that computes a constant (`int64_t`) bound. The `stopCondition` is optional: If none is provided, the traversal continues until a constant bound could be computed.
Differential Revision: https://reviews.llvm.org/D146296
Added:
Modified:
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Affine/value-bounds-reification.mlir
mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 97d27a04df893..a4a7c98ae3e01 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -16,6 +16,8 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include <queue>
+
namespace mlir {
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
@@ -100,6 +102,24 @@ class ValueBoundsConstraintSet {
std::optional<int64_t> dim,
StopConditionFn stopCondition);
+ /// Compute a constant bound for the given index-typed value or shape
+ /// dimension size.
+ ///
+ /// `dim` must be `nullopt` if and only if `value` is index-typed. This
+ /// function traverses the backward slice of the given value in a
+ /// worklist-driven manner until `stopCondition` evaluates to "true". The
+ /// constraint set is populated according to `ValueBoundsOpInterface` for each
+ /// visited value. (No constraints are added for values for which the stop
+ /// condition evaluates to "true".)
+ ///
+ /// The stop condition is optional: If none is specified, the backward slice
+ /// is traversed in a breadth-first manner until a constant bound could be
+ /// computed.
+ static FailureOr<int64_t>
+ computeConstantBound(presburger::BoundType type, Value value,
+ std::optional<int64_t> dim = std::nullopt,
+ StopConditionFn stopCondition = nullptr);
+
/// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound.
BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -162,7 +182,7 @@ class ValueBoundsConstraintSet {
DenseMap<ValueDim, int64_t> valueDimToPosition;
/// Worklist of values/shape dimensions that have not been processed yet.
- SetVector<int64_t> worklist;
+ std::queue<int64_t> worklist;
/// Constraint system of equalities and inequalities.
FlatLinearConstraints cstr;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 8db5e6865646b..a2885e22d01bb 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -121,7 +121,7 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
valueDimToPosition[positionToValueDim[i]] = i;
- worklist.insert(pos);
+ worklist.push(pos);
return pos;
}
@@ -148,7 +148,8 @@ static Operation *getOwnerOfValue(Value value) {
void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
while (!worklist.empty()) {
- int64_t pos = worklist.pop_back_val();
+ int64_t pos = worklist.front();
+ worklist.pop();
ValueDim valueDim = positionToValueDim[pos];
Value value = valueDim.first;
int64_t dim = valueDim.second;
@@ -337,6 +338,33 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return success();
}
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType type, Value value, std::optional<int64_t> dim,
+ StopConditionFn stopCondition) {
+#ifndef NDEBUG
+ assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+ // Process the backward slice of `value` (i.e., reverse use-def chain) until
+ // `stopCondition` is met.
+ ValueBoundsConstraintSet cstr(value, dim);
+ int64_t pos = cstr.getPos(value, dim);
+ if (stopCondition) {
+ cstr.processWorklist(stopCondition);
+ } else {
+ // No stop condition specified: Keep adding constraints until a bound could
+ // be computed.
+ cstr.processWorklist(/*stopCondition=*/[&](Value v) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ });
+ }
+
+ // Compute constant bound for `valueDim`.
+ if (auto bound = cstr.cstr.getConstantBound64(type, pos))
+ return type == BoundType::UB ? *bound + 1 : *bound;
+ return failure();
+}
+
ValueBoundsConstraintSet::BoundBuilder &
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index e5ee497e8e205..5b4d1f2f42c2f 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -26,6 +26,8 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
// CHECK-LABEL: func @reify_slice_bound(
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: "test.some_use"(%[[c5]])
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
@@ -33,8 +35,12 @@ func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
+
+ %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+ "test.some_use"(%bound_const) : (index) -> ()
}
return
}
@@ -77,6 +83,11 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
%lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index)
"test.some_use"(%lb1_ub) : (index) -> ()
+ // CHECK: %[[c129:.*]] = arith.constant 129 : index
+ // CHECK: "test.some_use"(%[[c129]])
+ %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+ "test.some_use"(%lb1_ub_const) : (index) -> ()
+
scf.for %iv1 = %lb1 to %ub1 step %c32 {
// CHECK: %[[c32:.*]] = arith.constant 32 : index
// CHECK: "test.some_use"(%[[c32]])
@@ -94,6 +105,11 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
// CHECK: "test.some_use"(%[[c32]])
%matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%matmul_ub) : (index) -> ()
+
+ // CHECK: %[[c32:.*]] = arith.constant 32 : index
+ // CHECK: "test.some_use"(%[[c32]])
+ %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+ "test.some_use"(%matmul_ub_const) : (index) -> ()
}
}
}
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 576759e4f21ca..614c6014fec98 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -80,6 +80,27 @@ func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
// -----
+func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
+ %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
+ // expected-error @below{{could not reify bound}}
+ %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_slice_static_constant(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: return %[[c5]]
+func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
+ %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
+ %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_slice_rank_reduce(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[sz:.*]]: index
// CHECK: return %[[sz]]
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index e2a06a51f4636..7f66db3b39993 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -63,7 +63,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
IRRewriter rewriter(funcOp.getContext());
WalkResult result = funcOp.walk([&](Operation *op) {
// Look for test.reify_bound ops.
- if (op->getName().getStringRef() == "test.reify_bound") {
+ if (op->getName().getStringRef() == "test.reify_bound" ||
+ op->getName().getStringRef() == "test.reify_constant_bound") {
if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
!op->getResultTypes()[0].isIndex()) {
op->emitOpError("invalid op");
@@ -94,22 +95,37 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
: std::make_optional<int64_t>(
op->getAttrOfType<IntegerAttr>("dim").getInt());
- // Reify value bound.
- rewriter.setInsertionPointAfter(op);
- FailureOr<OpFoldResult> reified;
- if (!reifyToFuncArgs) {
- // Reify in terms of the op's operands.
- reified =
- reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
- } else {
+ // Check if a constant was requested.
+ bool constant =
+ op->getName().getStringRef() == "test.reify_constant_bound";
+
+ // Prepare stop condition. By default, reify in terms of the op's
+ // operands. No stop condition is used when a constant was requested.
+ std::function<bool(Value)> stopCondition = [&](Value v) {
+ // Reify in terms of SSA values that are different from `value`.
+ return v != value;
+ };
+ if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
- auto stopCondition = [](Value v) {
+ stopCondition = stopCondition = [](Value v) {
auto bbArg = v.dyn_cast<BlockArgument>();
if (!bbArg)
return false;
return isa<FunctionOpInterface>(
bbArg.getParentBlock()->getParentOp());
};
+ }
+
+ // Reify value bound
+ rewriter.setInsertionPointAfter(op);
+ FailureOr<OpFoldResult> reified = failure();
+ if (constant) {
+ auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+ *boundType, value, dim, /*stopCondition=*/nullptr);
+ if (succeeded(reifiedConst))
+ reified =
+ FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+ } else {
reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
dim, stopCondition);
}
More information about the Mlir-commits
mailing list