[Mlir-commits] [mlir] 2861856 - [mlir][Vector] Add utility for computing scalable value bounds (#83876)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 21 07:19:01 PDT 2024
Author: Benjamin Maxwell
Date: 2024-03-21T14:18:56Z
New Revision: 2861856baf16e43a5e465e87022c6c2c2d238969
URL: https://github.com/llvm/llvm-project/commit/2861856baf16e43a5e465e87022c6c2c2d238969
DIFF: https://github.com/llvm/llvm-project/commit/2861856baf16e43a5e465e87022c6c2c2d238969.diff
LOG: [mlir][Vector] Add utility for computing scalable value bounds (#83876)
This adds a new API built with the `ValueBoundsConstraintSet` to compute
the bounds of possibly scalable quantities. It uses knowledge of the
range of vscale (which is defined by the target architecture), to solve
for the bound as either a constant or an expression in terms of vscale.
The result is an `AffineMap` that will always take at most one
parameter, vscale, and returns a single result, which is the bound of
`value`.
The API is defined as follows:
```c++
FailureOr<ConstantOrScalableBound>
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
Value value, std::optional<int64_t> dim,
unsigned vscaleMin, unsigned vscaleMax,
presburger::BoundType boundType,
bool closedUB = true,
StopConditionFn stopCondition = nullptr);
```
Note: `ConstantOrScalableBound` is a thin wrapper over the `AffineMap`
with a utility for converting the bound to a single quantity (i.e. a
size and scalable flag).
We believe this API could prove useful downstream in IREE (which uses a
similar analysis to hoist allocas, which currently fails for scalable
vectors).
Added:
mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/test/Dialect/Vector/test-scalable-bounds.mlir
Modified:
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Dialect/Vector/IR/CMakeLists.txt
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
new file mode 100644
index 00000000000000..31e19ff1ad39f7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -0,0 +1,104 @@
+//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
+#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
+
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+namespace mlir::vector {
+
+namespace detail {
+
+/// Parent class for the value bounds RTTIExtends. Uses protected inheritance to
+/// hide all ValueBoundsConstraintSet methods by default (as some do not use the
+/// ScalableValueBoundsConstraintSet, so may produce unexpected results).
+struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
+ using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet;
+};
+} // namespace detail
+
+/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds.
+struct ScalableValueBoundsConstraintSet
+ : public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
+ detail::ValueBoundsConstraintSet> {
+ ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
+ unsigned vscaleMax)
+ : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
+
+ using RTTIExtends::bound;
+ using RTTIExtends::StopConditionFn;
+
+ /// A thin wrapper over an `AffineMap` which can represent a constant bound,
+ /// or a scalable bound (in terms of vscale). The `AffineMap` will always
+ /// take at most one parameter, vscale, and returns a single result, which is
+ /// the bound of value.
+ struct ConstantOrScalableBound {
+ AffineMap map;
+
+ struct BoundSize {
+ int64_t baseSize{0};
+ bool scalable{false};
+ };
+
+ /// Get the (possibly) scalable size of the bound, returns failure if
+ /// the bound cannot be represented as a single quantity.
+ FailureOr<BoundSize> getSize() const;
+ };
+
+ /// Computes a (possibly) scalable bound for a given value. This is
+ /// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but
+ /// uses knowledge of the range of vscale to compute either a constant
+ /// bound, an expression in terms of vscale, or failure if no bound can
+ /// be computed.
+ ///
+ /// The resulting `AffineMap` will always take at most one parameter,
+ /// vscale, and return a single result, which is the bound of `value`.
+ ///
+ /// Note: `vscaleMin` must be `<=` `vscaleMax`. If `vscaleMin` ==
+ /// `vscaleMax`, the resulting bound (if found) will be constant.
+ static FailureOr<ConstantOrScalableBound>
+ computeScalableBound(Value value, std::optional<int64_t> dim,
+ unsigned vscaleMin, unsigned vscaleMax,
+ presburger::BoundType boundType, bool closedUB = true,
+ StopConditionFn stopCondition = nullptr);
+
+ /// Get the value of vscale. Returns `nullptr` if vscale has not been seen.
+ Value getVscaleValue() const { return vscale; }
+
+ /// Sets the value of vscale. Asserts if vscale has already been set.
+ void setVscale(vector::VectorScaleOp vscaleOp) {
+ assert(!vscale && "expected vscale to be unset");
+ vscale = vscaleOp.getResult();
+ }
+
+ /// The minimum possible value of vscale.
+ unsigned getVscaleMin() const { return vscaleMin; }
+
+ /// The maximum possible value of vscale.
+ unsigned getVscaleMax() const { return vscaleMax; }
+
+ static char ID;
+
+private:
+ const unsigned vscaleMin;
+ const unsigned vscaleMax;
+
+ // This will be set when the first `vector.vscale` operation is found within
+ // the `ValueBoundsOpInterface` implementation then reused from there on.
+ Value vscale = nullptr;
+};
+
+using ConstantOrScalableBound =
+ ScalableValueBoundsConstraintSet::ConstantOrScalableBound;
+
+} // namespace mlir::vector
+
+#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
diff --git a/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 00000000000000..4794bc9016c6f9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace vector {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 21775e11e07149..9bbf12d1325401 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -82,6 +82,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
@@ -174,6 +175,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
+ vector::registerValueBoundsOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 28dadfb9ecf868..b4ed0967e63f18 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/ExtensibleRTTI.h"
#include <queue>
@@ -63,7 +64,8 @@ using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
///
/// Note: Any modification of existing IR invalides the data stored in this
/// class. Adding new operations is allowed.
-class ValueBoundsConstraintSet {
+class ValueBoundsConstraintSet
+ : public llvm::RTTIExtends<ValueBoundsConstraintSet, llvm::RTTIRoot> {
protected:
/// Helper class that builds a bound for a shaped value dimension or
/// index-typed value.
@@ -107,6 +109,8 @@ class ValueBoundsConstraintSet {
};
public:
+ static char ID;
+
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
@@ -265,6 +269,16 @@ class ValueBoundsConstraintSet {
ValueBoundsConstraintSet(MLIRContext *ctx);
+ /// Populates the constraint set for a value/map without actually computing
+ /// the bound. Returns the position for the value/map (via the return value
+ /// and `posOut` output parameter).
+ int64_t populateConstraintsSet(Value value,
+ std::optional<int64_t> dim = std::nullopt,
+ StopConditionFn stopCondition = nullptr);
+ int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
+ StopConditionFn stopCondition = nullptr,
+ int64_t *posOut = nullptr);
+
/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
/// any further.
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 70f3fa8c297d4b..204462ffd047c6 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRVectorDialect
VectorOps.cpp
+ ValueBoundsOpInterfaceImpl.cpp
+ ScalableValueBoundsConstraintSet.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
new file mode 100644
index 00000000000000..6d7e3bc70f59de
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -0,0 +1,103 @@
+//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::vector {
+
+FailureOr<ConstantOrScalableBound::BoundSize>
+ConstantOrScalableBound::getSize() const {
+ if (map.isSingleConstant())
+ return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
+ if (map.getNumResults() != 1 || map.getNumInputs() != 1)
+ return failure();
+ auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
+ if (!binop || binop.getKind() != AffineExprKind::Mul)
+ return failure();
+ auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
+ if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
+ constant = cst.getValue();
+ return true;
+ }
+ return false;
+ };
+ // Match `s0 * cst` or `cst * s0`:
+ int64_t cst = 0;
+ auto lhs = binop.getLHS();
+ auto rhs = binop.getRHS();
+ if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
+ (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
+ return BoundSize{cst, /*scalable=*/true};
+ }
+ return failure();
+}
+
+char ScalableValueBoundsConstraintSet::ID = 0;
+
+// Solves for the bound of `value` as a constant or an affine map of vscale.
+FailureOr<ConstantOrScalableBound>
+ScalableValueBoundsConstraintSet::computeScalableBound(
+ Value value, std::optional<int64_t> dim, unsigned vscaleMin,
+ unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
+ StopConditionFn stopCondition) {
+ using namespace presburger;
+
+ assert(vscaleMin <= vscaleMax && "expected vscaleMin <= vscaleMax");
+ ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
+ vscaleMax);
+
+ int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
+
+ // Project out all variables but vscale, leaving constraints on vscale only.
+ scalableCstr.projectOut(
+ [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
+
+ assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
+ scalableCstr.positionToValueDim.size() &&
+ "inconsistent mapping state");
+
+ // Check that the only symbols left are vscale.
+ for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
+ if (i == pos)
+ continue;
+ if (scalableCstr.positionToValueDim[i] !=
+ ValueDim(scalableCstr.getVscaleValue(),
+ ValueBoundsConstraintSet::kIndexValue))
+ return failure();
+ }
+
+ // An EQ bound needs LB == UB, which requires the UB to be closed (not open).
+ SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
+ scalableCstr.cstr.getSliceBounds(
+ pos, 1, value.getContext(), &lowerBound, &upperBound,
+ closedUB || boundType == BoundType::EQ);
+
+ auto invalidBound = [](auto &bound) {
+ return !bound[0] || bound[0].getNumResults() != 1;
+ };
+
+ AffineMap bound = [&] {
+ if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
+ lowerBound[0] == upperBound[0])
+ return lowerBound[0];
+ if (boundType == BoundType::LB && !invalidBound(lowerBound))
+ return lowerBound[0];
+ if (boundType == BoundType::UB && !invalidBound(upperBound))
+ return upperBound[0];
+ return AffineMap{};
+ }();
+
+ if (!bound)
+ return failure();
+
+ return ConstantOrScalableBound{bound};
+}
+
+} // namespace mlir::vector
diff --git a/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..ca95072d9bb0f2
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,51 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir::vector {
+namespace {
+
+struct VectorScaleOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<VectorScaleOpInterface,
+ VectorScaleOp> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto *scalableCstr = dyn_cast<ScalableValueBoundsConstraintSet>(&cstr);
+ if (!scalableCstr)
+ return;
+ auto vscaleOp = cast<VectorScaleOp>(op);
+ assert(value == vscaleOp.getResult() && "invalid value");
+ if (auto vscale = scalableCstr->getVscaleValue()) {
+ // All copies of vscale are equivalent.
+ scalableCstr->bound(value) == cstr.getExpr(vscale);
+ } else {
+ // We know vscale is confined to [vscaleMin, vscaleMax].
+ scalableCstr->bound(value) >= scalableCstr->getVscaleMin();
+ scalableCstr->bound(value) <= scalableCstr->getVscaleMax();
+ scalableCstr->setVscale(vscaleOp);
+ }
+ }
+};
+
+} // namespace
+} // namespace mlir::vector
+
+void mlir::vector::registerValueBoundsOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+ vector::VectorScaleOp::attachInterface<vector::VectorScaleOpInterface>(
+ *ctx);
+ });
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 85abc2df894797..06ec3f4e135e9f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -70,6 +70,8 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
: builder(ctx) {}
+char ValueBoundsConstraintSet::ID = 0;
+
#ifndef NDEBUG
static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
if (value.getType().isIndex()) {
@@ -471,55 +473,87 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
closedUB);
}
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+ StopConditionFn stopCondition, bool closedUB) {
+ ValueDimList valueDims;
+ for (Value v : operands) {
+ assert(v.getType().isIndex() && "expected index type");
+ valueDims.emplace_back(v, std::nullopt);
+ }
+ return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+}
+
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList operands,
StopConditionFn stopCondition, bool closedUB) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
ValueBoundsConstraintSet cstr(map.getContext());
- int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+ int64_t pos = 0;
+ if (stopCondition) {
+ cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
+ } else {
+ // No stop condition specified: Keep adding constraints until a bound could
+ // be computed.
+ cstr.populateConstraintsSet(
+ map, operands,
+ [&](Value v, std::optional<int64_t> dim) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ },
+ &pos);
+ }
+ // Compute constant bound for `valueDim`.
+ int64_t ubAdjustment = closedUB ? 0 : 1;
+ if (auto bound = cstr.cstr.getConstantBound64(type, pos))
+ return type == BoundType::UB ? *bound + ubAdjustment : *bound;
+ return failure();
+}
+
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(
+ Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+#ifndef NDEBUG
+ assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+ AffineMap map =
+ AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+ Builder(value.getContext()).getAffineDimExpr(0));
+ return populateConstraintsSet(map, {{value, dim}}, stopCondition);
+}
+
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(
+ AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
+ int64_t *posOut) {
+ assert(map.getNumResults() == 1 && "expected affine map with one result");
+ int64_t pos = insert(/*isSymbol=*/false);
+ if (posOut)
+ *posOut = pos;
// Add map and operands to the constraint set. Dimensions are converted to
// symbols. All operands are added to the worklist.
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
- return cstr.getExpr(v.first, v.second);
+ return getExpr(v.first, v.second);
};
SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
SmallVector<AffineExpr> symReplacements = llvm::to_vector(
llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
- cstr.addBound(
+ addBound(
presburger::BoundType::EQ, pos,
map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
if (stopCondition) {
- cstr.processWorklist(stopCondition);
+ processWorklist(stopCondition);
} else {
- // No stop condition specified: Keep adding constraints until a bound could
- // be computed.
- cstr.processWorklist(
- /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
- return cstr.cstr.getConstantBound64(type, pos).has_value();
- });
+ // No stop condition specified: Keep adding constraints until the worklist
+ // is empty.
+ processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
}
- // Compute constant bound for `valueDim`.
- int64_t ubAdjustment = closedUB ? 0 : 1;
- if (auto bound = cstr.cstr.getConstantBound64(type, pos))
- return type == BoundType::UB ? *bound + ubAdjustment : *bound;
- return failure();
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
- StopConditionFn stopCondition, bool closedUB) {
- ValueDimList valueDims;
- for (Value v : operands) {
- assert(v.getType().isIndex() && "expected index type");
- valueDims.emplace_back(v, std::nullopt);
- }
- return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+ return pos;
}
FailureOr<int64_t>
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
new file mode 100644
index 00000000000000..245a6f5c13ac3d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -cse -verify-diagnostics \
+// RUN: -verify-diagnostics -split-input-file | FileCheck %s
+
+#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 32400, s0)>
+#map_dim_j = affine_map<(d0)[s0] -> (-d0 + 16, s0)>
+
+// Here the upper bound for min_i is 4 x vscale, as we know 4 x vscale is
+// always less than 32400. The bound for min_j is 16, as 16 is always less than
+// 4 x vscale_max (vscale_max is the UB for vscale).
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_0:.*]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK-LABEL: @fixed_size_loop_nest
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[UB_i:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_0]]()[%[[VSCALE]]]
+// CHECK-DAG: %[[UB_j:.*]] = arith.constant 16 : index
+// CHECK: "test.some_use"(%[[UB_i]], %[[UB_j]]) : (index, index) -> ()
+func.func @fixed_size_loop_nest() {
+ %c16 = arith.constant 16 : index
+ %c32400 = arith.constant 32400 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %vscale = vector.vscale
+ %c4_vscale = arith.muli %vscale, %c4 : index
+ scf.for %i = %c0 to %c32400 step %c4_vscale {
+ %min_i = affine.min #map_dim_i(%i)[%c4_vscale]
+ scf.for %j = %c0 to %c16 step %c4_vscale {
+ %min_j = affine.min #map_dim_j(%j)[%c4_vscale]
+ %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
+ }
+ }
+ return
+}
+
+// -----
+
+#map_dynamic_dim = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)>
+
+// Here the upper bounds for min_i and min_j are both (conservatively)
+// 4 x vscale, as we know that is always the largest value they could take:
+// if `dim < 4 x vscale` then 4 x vscale is an overestimate, and if
+// `dim > 4 x vscale` then the min will be clamped to 4 x vscale.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_1:.*]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK-LABEL: @dynamic_size_loop_nest
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[UB_ij:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_1]]()[%[[VSCALE]]]
+// CHECK: "test.some_use"(%[[UB_ij]], %[[UB_ij]]) : (index, index) -> ()
+func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %vscale = vector.vscale
+ %c4_vscale = arith.muli %vscale, %c4 : index
+ scf.for %i = %c0 to %dim0 step %c4_vscale {
+ %min_i = affine.min #map_dynamic_dim(%i)[%c4_vscale, %dim0]
+ scf.for %j = %c0 to %dim1 step %c4_vscale {
+ %min_j = affine.min #map_dynamic_dim(%j)[%c4_vscale, %dim1]
+ %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
+ }
+ }
+ return
+}
+
+// -----
+
+// Here the bound is just a value + a constant.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_2:.*]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @add_to_vscale
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_2]]()[%[[VSCALE]]]
+// CHECK: "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
+func.func @add_to_vscale() {
+ %vscale = vector.vscale
+ %c8 = arith.constant 8 : index
+ %vscale_plus_c8 = arith.addi %vscale, %c8 : index
+ %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound) : (index) -> ()
+ return
+}
+
+// -----
+
+// Here we know vscale is always 2 so we get a constant bound.
+
+// CHECK-LABEL: @vscale_fixed_size
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: "test.some_use"(%[[C2]]) : (index) -> ()
+func.func @vscale_fixed_size() {
+ %vscale = vector.vscale
+ %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index
+ "test.some_use"(%bound) : (index) -> ()
+ return
+}
+
+// -----
+
+// Here we don't know the upper bound (%a is underspecified)
+
+func.func @unknown_bound(%a: index) {
+ %vscale = vector.vscale
+ %vscale_plus_a = arith.muli %vscale, %a : index
+ // expected-error @below{{could not reify bound}}
+ %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound) : (index) -> ()
+ return
+}
+
+// -----
+
+// Here we have two vscale values (that have not been CSE'd), but they should
+// still be treated as equivalent.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_3:.*]] = affine_map<()[s0] -> (s0 * 6)>
+
+// CHECK-LABEL: @duplicate_vscale_values
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_3]]()[%[[VSCALE]]]
+// CHECK: "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
+func.func @duplicate_vscale_values() {
+ %c4 = arith.constant 4 : index
+ %vscale_0 = vector.vscale
+
+ %c2 = arith.constant 2 : index
+ %vscale_1 = vector.vscale
+
+ %c4_vscale = arith.muli %vscale_0, %c4 : index
+ %c2_vscale = arith.muli %vscale_1, %c2 : index
+ %add = arith.addi %c2_vscale, %c4_vscale : index
+
+ %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound) : (index) -> ()
+ return
+}
+
+// -----
+
+// Test some non-scalable code to ensure that works too:
+
+#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 1024, s0)>
+
+// CHECK-LABEL: @non_scalable_code
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: "test.some_use"(%[[C4]]) : (index) -> ()
+func.func @non_scalable_code() {
+ %c1024 = arith.constant 1024 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ scf.for %i = %c0 to %c1024 step %c4 {
+ %min_i = affine.min #map_dim_i(%i)[%c4]
+ %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ "test.some_use"(%bound_i) : (index) -> ()
+ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..5e160b720db627 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
@@ -75,7 +76,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
WalkResult result = funcOp.walk([&](Operation *op) {
// Look for test.reify_bound ops.
if (op->getName().getStringRef() == "test.reify_bound" ||
- op->getName().getStringRef() == "test.reify_constant_bound") {
+ op->getName().getStringRef() == "test.reify_constant_bound" ||
+ op->getName().getStringRef() == "test.reify_scalable_bound") {
if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
!op->getResultTypes()[0].isIndex()) {
op->emitOpError("invalid op");
@@ -110,6 +112,9 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
bool constant =
op->getName().getStringRef() == "test.reify_constant_bound";
+ bool scalable = !constant && op->getName().getStringRef() ==
+ "test.reify_scalable_bound";
+
// Prepare stop condition. By default, reify in terms of the op's
// operands. No stop condition is used when a constant was requested.
std::function<bool(Value, std::optional<int64_t>)> stopCondition =
@@ -137,6 +142,37 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
if (succeeded(reifiedConst))
reified =
FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+ } else if (scalable) {
+ unsigned vscaleMin = 0;
+ unsigned vscaleMax = 0;
+ if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr)) {
+ vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+ } else {
+ op->emitOpError("expected `vscale_min` to be provided");
+ return WalkResult::skip();
+ }
+ if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr)) {
+ vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+ } else {
+ op->emitOpError("expected `vscale_max` to be provided");
+ return WalkResult::skip();
+ }
+
+ auto loc = op->getLoc();
+ auto reifiedScalable =
+ vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+ value, dim, vscaleMin, vscaleMax, *boundType);
+ if (succeeded(reifiedScalable)) {
+ SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
+ vscaleOperand;
+ if (reifiedScalable->map.getNumInputs() == 1) {
+ // The only possible input to the bound is vscale.
+ vscaleOperand.push_back(std::make_pair(
+ rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
+ }
+ reified = affine::materializeComputedBound(
+ rewriter, loc, reifiedScalable->map, vscaleOperand);
+ }
} else {
if (dim) {
if (useArithOps) {
More information about the Mlir-commits
mailing list