[llvm-branch-commits] [mlir] [mlir][Interfaces][WIP] `Variable` abstraction for `ValueBoundsOpInterface` (PR #87980)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Apr 8 06:54:04 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/87980
>From 183ab14683a335e65891e1d585f69231699efceb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 8 Apr 2024 11:25:29 +0000
Subject: [PATCH] [mlir][Interfaces][WIP] `ValueBoundsOpInterface`: `Variable`
---
.../mlir/Interfaces/ValueBoundsOpInterface.h | 117 +++---
.../Affine/IR/ValueBoundsOpInterfaceImpl.cpp | 6 +-
.../Affine/Transforms/ReifyValueBounds.cpp | 2 +-
.../Arith/IR/ValueBoundsOpInterfaceImpl.cpp | 8 +-
.../Dialect/Arith/Transforms/IntNarrowing.cpp | 2 +-
.../Arith/Transforms/ReifyValueBounds.cpp | 4 +-
.../lib/Dialect/Linalg/Transforms/Padding.cpp | 6 +-
.../Dialect/Linalg/Transforms/Promotion.cpp | 6 +-
.../Transforms/IndependenceTransforms.cpp | 5 +-
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 17 +-
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 3 +-
.../Transforms/IndependenceTransforms.cpp | 3 +-
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 4 +-
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 337 ++++++++----------
.../Dialect/Affine/TestReifyValueBounds.cpp | 6 +-
15 files changed, 247 insertions(+), 279 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3..3e1502b4f5c357 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
public:
static char ID;
+ /// A variable that can be added to the constraint set as a "column". The
+ /// value bounds infrastructure can compute bounds for variables and compare
+ /// two variables.
+ ///
+ /// Internally, a variable is represented as a single-result affine map whose
+ /// inputs are all symbols, together with one (value, dim) operand per symbol
+ /// (`mapOperands`). Integer constants are represented by a constant map
+ /// without operands.
+ ///
+ /// The constructors are not `explicit`, so a `Value`, an `OpFoldResult` or a
+ /// brace list such as `{value, dim}` can be passed wherever a `Variable` is
+ /// expected.
+ class Variable {
+ public:
+ /// Construct a variable for an index-typed attribute or SSA value.
+ Variable(OpFoldResult ofr);
+
+ /// Construct a variable for an index-typed SSA value.
+ Variable(Value indexValue);
+
+ /// Construct a variable for a dimension of a shaped value.
+ Variable(Value shapedValue, int64_t dim);
+
+ /// Construct a variable for an index-typed attribute/SSA value or for a
+ /// dimension of a shaped value. `dim` must be set if and only if `ofr` is a
+ /// shaped value; integer constants must not come with a `dim`.
+ Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+ /// Construct a variable for a single-result map and its operands. The map's
+ /// dims and symbols are bound to `mapOperands` in order (dims first, then
+ /// symbols) and the operand variables are inlined into the map. In the
+ /// `ArrayRef<Value>` overload, each value is treated as an index-typed
+ /// variable.
+ Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+ Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+ /// Return the MLIR context of the underlying affine map.
+ MLIRContext *getContext() const { return map.getContext(); }
+
+ private:
+ friend class ValueBoundsConstraintSet;
+ /// Single-result affine map over symbols only (no dims).
+ AffineMap map;
+ /// Operands bound to the symbols of `map`, in symbol order.
+ ValueDimList mapOperands;
+ };
+
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
- /// Compute a bound for the given index-typed value or shape dimension size.
- /// The computed bound is stored in `resultMap`. The operands of the bound are
- /// stored in `mapOperands`. An operand is either an index-type SSA value
- /// or a shaped value and a dimension.
+ /// Compute a bound for the given variable. The computed bound is stored in
+ /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+ /// operand is either an index-type SSA value or a shaped value and a
+ /// dimension.
///
- /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
- /// is computed in terms of values/dimensions for which `stopCondition`
- /// evaluates to "true". To that end, the backward slice (reverse use-def
- /// chain) of the given value is visited in a worklist-driven manner and the
- /// constraint set is populated according to `ValueBoundsOpInterface` for each
- /// visited value.
+ /// The bound is computed in terms of values/dimensions for which
+ /// `stopCondition` evaluates to "true". To that end, the backward slice
+ /// (reverse use-def chain) of the variable's operands is visited in a
+ /// worklist-driven manner and the constraint set is populated according to
+ /// `ValueBoundsOpInterface` for each visited value.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
- static LogicalResult computeBound(AffineMap &resultMap,
- ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- StopConditionFn stopCondition,
- bool closedUB = false);
+ static LogicalResult
+ computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+ presburger::BoundType type, const Variable &var,
+ StopConditionFn stopCondition, bool closedUB = false);
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueDimList dependencies, bool closedUB = false);
/// Compute a bound in that is independent of all values in `independencies`.
///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
/// appear in the computed bound.
static LogicalResult
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueRange independencies, bool closedUB = false);
- /// Compute a constant bound for the given affine map, where dims and symbols
- /// are bound to the given operands. The affine map must have exactly one
- /// result.
+ /// Compute a constant bound for the given variable.
///
/// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
- computeConstantBound(presburger::BoundType type, Value value,
- std::optional<int64_t> dim = std::nullopt,
+ computeConstantBound(presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
+ bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
- static bool compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands);
- static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands);
-
- /// Compute whether the given values/dimensions are equal. Return "failure" if
+ static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+ /// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
- ///
- /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
- /// index-typed.
- static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1 = std::nullopt,
- std::optional<int64_t> dim2 = std::nullopt);
+ static FailureOr<bool> areEqual(Variable var1, Variable var2);
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
- bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+ /// Insert a new column for the given variable; equivalent to calling the
+ /// overload above with the variable's map and operands.
+ int64_t insert(const Variable &var, bool isSymbol = true);
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
/// Project out all columns for which the condition holds.
void projectOut(function_ref<bool(ValueDim)> condition);
+ /// Project out all anonymous columns (columns that are not mapped to a
+ /// value/dimension), except for the column at position `except`, if set.
+ void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
/// Mapping of columns to values/shape dimensions.
SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d..82a9fb0d490882 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
- ValueDimList valueDims;
- for (Value v : mapOperands)
- valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::EQ, map, valueDims);
+ presburger::BoundType::EQ,
+ ValueBoundsConstraintSet::Variable(map, mapOperands));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7..6c59df91e8af78 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
return failure();
// Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d43808bc45df..7cfcc4180539c2 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
- if (cstr.compare(trueValue, dim,
+ if (cstr.compare(/*lhs=*/{trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- falseValue, dim)) {
+ /*rhs=*/{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
- if (cstr.compare(falseValue, dim,
+ if (cstr.compare(/*lhs=*/{falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- trueValue, dim)) {
+ /*rhs=*/{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a..f87f3d6350c022 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190e..5bb7d83bf1e3f8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type,
+ ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+ closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db248989..518d2e138c02a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, opOperand->get(),
- /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB,
+ {opOperand->get(),
+ /*dim=*/i},
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049..71eb59d40836c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
- Value materializedSize =
- getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
- ? materializedSize
+ ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9..1f06318cbd60e0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueRange independencies) {
if (ofr.is<Attribute>())
return ofr;
- Value value = ofr.get<Value>();
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
- boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+ /*closedUB=*/true)))
return failure();
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a..17a1c016ea16d5 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
if (cstr.populateAndCompare(
- yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
- iterArg, dim)) {
+ /*lhs=*/{yieldedValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ /*rhs=*/{iterArg, dim})) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
- cstr.bound(value) == initArg;
+ cstr.bound(value) == cstr.getExpr(initArg);
}
}
}
@@ -113,8 +114,9 @@ struct IfOpInterface
// * result <= elseValue
// * result >= thenValue
if (cstr.populateAndCompare(
- thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ /*lhs=*/{thenValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{elseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
// * result <= thenValue
// * result >= elseValue
if (cstr.populateAndCompare(
- elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ /*lhs=*/{elseValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{thenValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c13..d25efcf50ec566 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
info.isAlignedToInnerTileSize = false;
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB,
- getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+ presburger::BoundType::UB, tileSize,
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49b3..a89ce20048dff3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ independencies,
+ /*closedUB=*/true)))
return failure();
return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a1700..15381ec520e211 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), srcDim, resultDim);
+ {op.getSource(), srcDim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), dim, resultDim);
+ {op.getSource(), dim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++resultDim;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index fa66da4a0def93..9f220f5f6ceb72 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir
+/// Return the operation that owns `value`: the parent op of the owning block
+/// for a block argument, otherwise the defining op.
+/// NOTE(review): this can presumably return nullptr (e.g. for a block argument
+/// of a block that is not attached to an op). The LLVM_DEBUG output in
+/// `insert(Value, ...)` and `getPos` dereferences the result unconditionally;
+/// confirm that this cannot happen there.
+static Operation *getOwnerOfValue(Value value) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
+ return bbArg.getOwner()->getParentOp();
+ return value.getDefiningOp();
+}
+
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
+/// Construct a variable for an integer constant attribute or an index-typed
+/// SSA value; delegates to the (ofr, dim) constructor without a dimension.
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+ : Variable(ofr, std::nullopt) {}
+
+/// Construct a variable for an index-typed SSA value.
+// The cast selects the `OpFoldResult` overload; delegating with a plain
+// `Value` would pick this constructor again.
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+ : Variable(static_cast<OpFoldResult>(indexValue)) {}
+
+/// Construct a variable for dimension `dim` of the shaped value `shapedValue`.
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+ : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
+
+/// Construct a variable for an integer constant / index-typed SSA value
+/// (`dim == nullopt`) or for dimension `*dim` of a shaped SSA value.
+///
+/// - If the file-local `getConstantIntValue` helper recognizes `ofr` as an
+///   integer constant, the variable is a constant map with no operands; such
+///   a constant must not come with a `dim` (asserted).
+/// - Otherwise `ofr` must be an SSA value. The variable is the map
+///   `()[s0] -> (s0)` with the single operand `(value, dim)`.
+/// NOTE(review): an attribute that is not recognized as an integer constant
+/// reaches `cast<Value>` below, which presumably asserts; confirm that callers
+/// never pass such attributes.
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+ std::optional<int64_t> dim) {
+ Builder b(ofr.getContext());
+ // `::` selects the file-local `getConstantIntValue(OpFoldResult)` helper.
+ if (auto constInt = ::getConstantIntValue(ofr)) {
+ assert(!dim && "expected no dim for index-typed values");
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+ b.getAffineConstantExpr(*constInt));
+ return;
+ }
+ Value value = cast<Value>(ofr);
+#ifndef NDEBUG
+ // Shaped type is required iff a dimension is given; index type otherwise.
+ if (dim) {
+ assert(isa<ShapedType>(value.getType()) && "expected shaped type");
+ } else {
+ assert(value.getType().isIndex() && "expected index type");
+ }
+#endif // NDEBUG
+ // Identity map over a single symbol that is bound to (value, dim).
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+ b.getAffineSymbolExpr(0));
+ mapOperands.emplace_back(value, dim);
+}
+
+/// Construct a variable from a single-result affine map whose dims and symbols
+/// are bound to `mapOperands` (dims first, then symbols). Each operand
+/// variable's expression is inlined into the map. The result is a
+/// single-result, symbols-only map over the (value, dim) operands of all
+/// operand variables; an operand that appears more than once gets only one
+/// symbol.
+/// NOTE(review): there is no check that `mapOperands.size()` equals
+/// `map.getNumInputs()`. A mismatch would presumably leave un-substituted
+/// symbols or record operands that the map does not use. Consider asserting
+/// this precondition.
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Variable> mapOperands) {
+ assert(map.getNumResults() == 1 && "expected single result");
+
+ // Turn all dims into symbols: dim i becomes symbol i and symbol j becomes
+ // symbol (numDims + j), so that operand k corresponds to symbol k.
+ Builder b(map.getContext());
+ SmallVector<AffineExpr> dimReplacements, symReplacements;
+ for (int64_t i = 0; i < map.getNumDims(); ++i)
+ dimReplacements.push_back(b.getAffineSymbolExpr(i));
+ for (int64_t i = 0; i < map.getNumSymbols(); ++i)
+ symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+ AffineMap tmpMap = map.replaceDimsAndSymbols(
+ dimReplacements, symReplacements, /*numResultDims=*/0,
+ /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+ // Inline operands: symbol `index` of `tmpMap` is replaced with the
+ // expression of operand variable `index`, after renumbering that
+ // expression's symbols into this variable's operand list. `this->` is
+ // required because the parameters `map`/`mapOperands` shadow the members.
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ for (auto [index, var] : llvm::enumerate(mapOperands)) {
+ assert(var.map.getNumResults() == 1 && "expected single result");
+ assert(var.map.getNumDims() == 0 && "expected only symbols");
+ // Note: shadows the outer `symReplacements` (used only above).
+ SmallVector<AffineExpr> symReplacements;
+ for (auto valueDim : var.mapOperands) {
+ auto it = llvm::find(this->mapOperands, valueDim);
+ if (it != this->mapOperands.end()) {
+ // There is already a symbol for this operand.
+ symReplacements.push_back(b.getAffineSymbolExpr(
+ std::distance(this->mapOperands.begin(), it)));
+ } else {
+ // This is a new operand: add a new symbol.
+ symReplacements.push_back(
+ b.getAffineSymbolExpr(this->mapOperands.size()));
+ this->mapOperands.push_back(valueDim);
+ }
+ }
+ replacements[b.getAffineSymbolExpr(index)] =
+ var.map.getResult(0).replaceSymbols(symReplacements);
+ }
+ // Apply all collected substitutions to `tmpMap` in a single `replace` call;
+ // the result has one symbol per collected operand and no dims.
+ this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+ /*numResultSyms=*/this->mapOperands.size());
+}
+
+/// Construct a variable from a map whose inputs are bound to SSA values. Each
+/// value is wrapped with `Variable(Value)`, so non-constant values must be
+/// index-typed (asserted in debug builds). The result is forwarded to the
+/// `ArrayRef<Variable>` constructor.
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Value> mapOperands)
+ : Variable(map, llvm::map_to_vector(mapOperands,
+ [](Value v) { return Variable(v); })) {}
+
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
MLIRContext *ctx, StopConditionFn stopCondition)
: builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
assert(!valueDimToPosition.contains(valueDim) && "already mapped");
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+ << " for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+ << "\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
return pos;
}
+/// Insert a column for `var` by forwarding its map and operands to the
+/// `insert(AffineMap, ValueDimList, bool)` overload. Returns the new column's
+/// position.
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+ return insert(var.map, var.mapOperands, isSymbol);
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
"unstructured control flow is not supported");
#endif // NDEBUG
-
+ LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
auto it =
valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
return it != valueDimToPosition.end();
}
-static Operation *getOwnerOfValue(Value value) {
- if (auto bbArg = dyn_cast<BlockArgument>(value))
- return bbArg.getOwner()->getParentOp();
- return value.getDefiningOp();
-}
-
void ValueBoundsConstraintSet::processWorklist() {
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
}
}
+/// Project out every anonymous column, i.e. every column with no entry in
+/// `positionToValueDim`. The column at position `except` (if set) is kept.
+///
+/// NOTE(review): `except` is compared against positions that shift as columns
+/// are projected out. It is only reliable if no anonymous column precedes
+/// `except` (e.g. `except == 0`, as used by `computeBound`). Confirm this for
+/// any other caller.
+void ValueBoundsConstraintSet::projectOutAnonymous(
+ std::optional<int64_t> except) {
+ int64_t nextPos = 0;
+ while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+ if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+ ++nextPos;
+ } else {
+ projectOut(nextPos);
+ // The column was projected out so another column is now at that position.
+ // Do not increase the counter.
+ }
+ }
+}
+
LogicalResult ValueBoundsConstraintSet::computeBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
- bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+ const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+ MLIRContext *ctx = var.getContext();
int64_t ubAdjustment = closedUB ? 0 : 1;
- Builder b(value.getContext());
+ Builder b(ctx);
mapOperands.clear();
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
- ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
- assert(!stopCondition(value, dim, cstr) &&
- "stop condition should not be satisfied for starting point");
- int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+ ValueBoundsConstraintSet cstr(ctx, stopCondition);
+ int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+ assert(pos == 0 && "expected first column");
cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
cstr.projectOut([&](ValueDim p) {
- // Do not project out `valueDim`.
- if (valueDim == p)
- return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
return !stopCondition(p.first, maybeDim, cstr);
});
+ cstr.projectOutAnonymous(/*except=*/pos);
// Compute lower and upper bounds for `valueDim`.
SmallVector<AffineMap> lb(1), ub(1);
- cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+ cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
/*closedUB=*/true);
// Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
LogicalResult ValueBoundsConstraintSet::computeDependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB) {
+ const Variable &var, ValueDimList dependencies, bool closedUB) {
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB) {
+ const Variable &var, ValueRange independencies, bool closedUB) {
// Return "true" if the given value is independent of all values in
// `independencies`. I.e., neither the value itself nor any value in the
// backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return isIndependent(v);
},
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
}
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, Value value, std::optional<int64_t> dim,
- StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
- AffineMap map =
- AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- Builder(value.getContext()).getAffineDimExpr(0));
- return computeConstantBound(type, map, {{value, dim}}, stopCondition,
- closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+ presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition, bool closedUB) {
- ValueDimList valueDims;
- for (Value v : operands) {
- assert(v.getType().isIndex() && "expected index type");
- valueDims.emplace_back(v, std::nullopt);
- }
- return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList operands,
- StopConditionFn stopCondition, bool closedUB) {
- assert(map.getNumResults() == 1 && "expected affine map with one result");
-
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
};
ValueBoundsConstraintSet cstr(
- map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
- pos = cstr.populateConstraints(map, operands);
+ var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ pos = cstr.populateConstraints(var.map, var.mapOperands);
assert(pos == 0 && "expected `map` is the first column");
// Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
Builder b(value1.getContext());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
- return computeConstantBound(presburger::BoundType::EQ, map,
- {{value1, dim1}, {value2, dim2}});
+ return computeConstantBound(presburger::BoundType::EQ,
+ Variable(map, {{value1, dim1}, {value2, dim2}}));
}
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp,
- OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+ ComparisonOperator cmp,
+ int64_t rhsPos) {
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
return false;
}
- // EQ can be expressed as LE and GE.
- if (cmp == EQ)
- return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
- compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
- // Construct inequality. For the above example: lhs > rhs.
- // `IntegerRelation` inequalities are expressed in the "flattened" form and
- // with ">= 0". I.e., lhs - rhs - 1 >= 0.
- SmallVector<int64_t> eq(cstr.getNumCols(), 0);
- auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
- int64_t factor) {
- if (auto constVal = ::getConstantIntValue(ofr)) {
- eq[cstr.getNumCols() - 1] += *constVal * factor;
- } else {
- eq[getPos(cast<Value>(ofr), dim)] += factor;
- }
- };
- if (cmp == LT || cmp == LE) {
- addToEq(lhs, lhsDim, 1);
- addToEq(rhs, rhsDim, -1);
- } else if (cmp == GT || cmp == GE) {
- addToEq(lhs, lhsDim, -1);
- addToEq(rhs, rhsDim, 1);
- } else {
- llvm_unreachable("unsupported comparison operator");
- }
- if (cmp == LE || cmp == GE)
- eq[cstr.getNumCols() - 1] -= 1;
-
- // Add inequality to the constraint set and check if it made the constraint
- // set empty.
- int64_t ineqPos = cstr.getNumInequalities();
- cstr.addInequality(eq);
- bool isEmpty = cstr.isEmpty();
- cstr.removeInequality(ineqPos);
- return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
- ComparisonOperator cmp,
- int64_t rhsPos) {
- // This function returns "true" if "lhs CMP rhs" is proven to hold. For
- // detailed documentation, see `compareValueDims`.
-
// EQ can be expressed as LE and GE.
if (cmp == EQ)
return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
@@ -712,48 +727,16 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
-bool ValueBoundsConstraintSet::populateAndCompare(
- OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
- OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
- if (auto lhsVal = dyn_cast<Value>(lhs))
- populateConstraints(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- populateConstraints(rhsVal, rhsDim);
-
- return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+bool ValueBoundsConstraintSet::populateAndCompare(Variable lhs,
+ ComparisonOperator cmp,
+ Variable rhs) {
+ int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
+ int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
+ return comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
- auto stopCondition = [&](Value v, std::optional<int64_t> dim,
- ValueBoundsConstraintSet &cstr) {
- // Keep processing as long as lhs/rhs are not mapped.
- if (auto lhsVal = dyn_cast<Value>(lhs))
- if (!cstr.isMapped(lhsVal, dim))
- return false;
- if (auto rhsVal = dyn_cast<Value>(rhs))
- if (!cstr.isMapped(rhsVal, dim))
- return false;
- // Keep processing as long as the relation cannot be proven.
- return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
- };
-
- ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(Variable lhs, ComparisonOperator cmp,
+ Variable rhs) {
int64_t lhsPos = -1, rhsPos = -1;
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
@@ -765,39 +748,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
return cstr.comparePos(lhsPos, cmp, rhsPos);
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- lhsPos = cstr.insert(lhs, lhsOperands);
- rhsPos = cstr.insert(rhs, rhsOperands);
- cstr.processWorklist();
+ lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+ rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
- ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands) {
- ValueDimList lhsValueDimOperands =
- llvm::map_to_vector(lhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- ValueDimList rhsValueDimOperands =
- llvm::map_to_vector(rhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
- rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1,
- std::optional<int64_t> dim2) {
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
- value2, dim2))
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(Variable var1,
+ Variable var2) {
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
return true;
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
- value2, dim2) ||
- ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
- value2, dim2))
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+ ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
return false;
return failure();
}
@@ -833,7 +794,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
@@ -850,7 +811,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index f38631054fb3c1..af4ba7de3df1f6 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -169,7 +169,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
FailureOr<OpFoldResult> reified = failure();
if (constant) {
auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
- *boundType, value, dim, /*stopCondition=*/nullptr);
+ *boundType, {value, dim}, /*stopCondition=*/nullptr);
if (succeeded(reifiedConst))
reified =
FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
@@ -285,8 +285,8 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
return ValueBoundsConstraintSet::compare(
- /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
- /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
+ /*lhs=*/op->getOperand(0), cmp,
+ /*rhs=*/op->getOperand(1));
};
if (compare(*cmpType)) {
op->emitRemark("true");
More information about the llvm-branch-commits
mailing list