[Mlir-commits] [mlir] 041bc48 - [mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds
Matthias Springer
llvmlistbot at llvm.org
Thu Apr 6 18:53:20 PDT 2023
Author: Matthias Springer
Date: 2023-04-07T10:48:19+09:00
New Revision: 041bc485bf2122b238eb1a336d3a38168feb8eaa
URL: https://github.com/llvm/llvm-project/commit/041bc485bf2122b238eb1a336d3a38168feb8eaa
DIFF: https://github.com/llvm/llvm-project/commit/041bc485bf2122b238eb1a336d3a38168feb8eaa.diff
LOG: [mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds
This change also adds support for `affine.min` and `affine.max` ops.
Differential Revision: https://reviews.llvm.org/D145787
Added:
Modified:
mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 2abd329b51620..e957fb1cd0dbd 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -51,14 +51,15 @@ FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
/// Reify a bound for the given index-typed value or shape dimension size in
/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
-/// `value` is index-typed.
+/// `value` is index-typed. LB and EQ bounds are closed, UB bounds are open.
FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
presburger::BoundType type, Value value,
std::optional<int64_t> dim);
/// Reify a bound for the given index-typed value or shape dimension size in
/// terms of SSA values for which `stopCondition` is met. `dim` must be
-/// `nullopt` if and only if `value` is index-typed.
+/// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are
+/// closed, UB bounds are open.
///
/// Example:
/// %0 = arith.addi %a, %b : index
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index ed8b9a74e2e77..0036023a8a015 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -38,6 +38,48 @@ struct AffineApplyOpInterface
}
};
+/// External model of ValueBoundsOpInterface for `affine.min`: the op result
+/// is constrained to be <= each result expression of the op's affine map
+/// (with map dims/symbols replaced by the constraint-set expressions of the
+/// corresponding operands). No lower bound is added for the result.
+struct AffineMinOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
+                                                   AffineMinOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto minOp = cast<AffineMinOp>(op);
+    assert(value == minOp.getResult() && "invalid value");
+
+    // Align affine map dims/symbols with the constraint set. The operands are
+    // shared by all map results, so the replacements are computed once
+    // instead of once per map result.
+    SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
+        minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
+        minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+
+    // min(e_0, ..., e_n) <= e_i holds for every map result e_i.
+    for (AffineExpr expr : minOp.getAffineMap().getResults()) {
+      AffineExpr bound =
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+      cstr.bound(value) <= bound;
+    }
+  }
+};
+
+/// External model of ValueBoundsOpInterface for `affine.max`: the op result
+/// is constrained to be >= each result expression of the op's affine map
+/// (with map dims/symbols replaced by the constraint-set expressions of the
+/// corresponding operands). No upper bound is added for the result.
+struct AffineMaxOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMaxOpInterface,
+                                                   AffineMaxOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto maxOp = cast<AffineMaxOp>(op);
+    assert(value == maxOp.getResult() && "invalid value");
+
+    // Align affine map dims/symbols with the constraint set. The operands are
+    // shared by all map results, so the replacements are computed once
+    // instead of once per map result.
+    SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
+        maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
+        maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+
+    // max(e_0, ..., e_n) >= e_i holds for every map result e_i.
+    for (AffineExpr expr : maxOp.getAffineMap().getResults()) {
+      AffineExpr bound =
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+      cstr.bound(value) >= bound;
+    }
+  }
+};
+
} // namespace
} // namespace mlir
@@ -45,5 +87,7 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
+    // Attach the ValueBoundsOpInterface external models defined in this file.
     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+    AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
+    AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 73757b611520a..8db5e6865646b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -214,9 +214,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
assertValidValueDim(value, dim);
#endif // NDEBUG
- // Only EQ bounds are supported at the moment.
- assert(type == BoundType::EQ && "unsupported bound type");
-
Builder b(value.getContext());
mapOperands.clear();
@@ -249,16 +246,39 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
SmallVector<AffineMap> lb(1), ub(1);
cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
/*getClosedUB=*/true);
+
// Note: There are TODOs in the implementation of `getSliceBounds`. In such a
// case, no lower/upper bound can be computed at the moment.
- if (lb.empty() || !lb[0] || ub.empty() || !ub[0] ||
- lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1)
+ // EQ, UB bounds: upper bound is needed.
+ if ((type != BoundType::LB) &&
+ (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
return failure();
+ // EQ, LB bounds: lower bound is needed.
+ if ((type != BoundType::UB) &&
+ (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
+ return failure();
+
+ // TODO: Generate an affine map with multiple results.
+ if (type != BoundType::LB)
+ assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
+ "multiple bounds not supported");
+ if (type != BoundType::UB)
+ assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
+ "multiple bounds not supported");
- // Look for same lower and upper bound: EQ bound.
- if (ub[0] != lb[0])
+ // EQ bound: lower and upper bound must match.
+ if (type == BoundType::EQ && ub[0] != lb[0])
return failure();
+ AffineMap bound;
+ if (type == BoundType::EQ || type == BoundType::LB) {
+ bound = lb[0];
+ } else {
+ // Computed UB is a closed bound. Turn into an open bound.
+ bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
+ ub[0].getResult(0) + 1);
+ }
+
// Gather all SSA values that are used in the computed bound.
assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
"inconsistent mapping state");
@@ -273,10 +293,10 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
bool used = false;
bool isDim = i < cstr.cstr.getNumDimVars();
if (isDim) {
- if (lb[0].isFunctionOfDim(i))
+ if (bound.isFunctionOfDim(i))
used = true;
} else {
- if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
+ if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
used = true;
}
@@ -312,7 +332,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
mapOperands.push_back(std::make_pair(value, dim));
}
- resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols,
+ resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
numDims, numSymbols);
return success();
}
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 436ec4a8326cb..338c48c5b210b 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -12,3 +12,49 @@ func.func @affine_apply(%a: index, %b: index) -> index {
%1 = "test.reify_bound"(%0) : (index) -> (index)
return %1 : index
}
+
+// -----
+
+// Reifying an LB of an affine.max result succeeds: the result is >= every
+// map result, and LB bounds are closed, so the constant 2 is reified as-is.
+// CHECK-LABEL: func @affine_max_lb(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: return %[[c2]]
+func.func @affine_max_lb(%a: index) -> (index) {
+ // Note: There are two LBs: s0 and 2. FlatAffineValueConstraints always
+ // returns the constant one at the moment.
+ %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+ %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+ return %2 : index
+}
+
+// -----
+
+// affine.max only constrains its result from below, so reifying a UB of it
+// is expected to fail.
+func.func @affine_max_ub(%a: index) -> (index) {
+ %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+ // expected-error @below{{could not reify bound}}
+ %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+ return %2 : index
+}
+
+// -----
+
+// Reifying a UB of an affine.min result succeeds. Reified UB bounds are open,
+// so the closed bound 2 is reified as 3.
+// CHECK-LABEL: func @affine_min_ub(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: return %[[c3]]
+func.func @affine_min_ub(%a: index) -> (index) {
+ // Note: There are two UBs: s0 + 1 and 3. FlatAffineValueConstraints always
+ // returns the constant one at the moment.
+ %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+ %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+ return %2 : index
+}
+
+// -----
+
+// affine.min only constrains its result from above, so reifying an LB of it
+// is expected to fail.
+func.func @affine_min_lb(%a: index) -> (index) {
+ %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+ // expected-error @below{{could not reify bound}}
+ %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+ return %2 : index
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 582831565aa73..e2a06a51f4636 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -16,6 +16,7 @@
#define PASS_NAME "test-affine-reify-value-bounds"
using namespace mlir;
+using mlir::presburger::BoundType;
namespace {
@@ -45,6 +46,16 @@ struct TestReifyValueBounds
} // namespace
+/// Parse the value of a "test.reify_bound" op's `type` attribute. Accepts
+/// exactly "EQ", "LB" or "UB" (case-sensitive); any other string yields
+/// failure(). File-local helper, hence `static` like the other helpers here.
+static FailureOr<BoundType> parseBoundType(const std::string &type) {
+  if (type == "EQ")
+    return BoundType::EQ;
+  if (type == "LB")
+    return BoundType::LB;
+  if (type == "UB")
+    return BoundType::UB;
+  return failure();
+}
+
/// Look for "test.reify_bound" ops in the input and replace their results with
/// the reified values.
static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
@@ -67,6 +78,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return WalkResult::skip();
}
+ // Get bound type.
+ std::string boundTypeStr = "EQ";
+ if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
+ boundTypeStr = boundTypeAttr.str();
+ auto boundType = parseBoundType(boundTypeStr);
+ if (failed(boundType)) {
+ op->emitOpError("invalid op");
+ return WalkResult::interrupt();
+ }
+
+ // Get shape dimension (if any).
auto dim = value.getType().isIndex()
? std::nullopt
: std::make_optional<int64_t>(
@@ -77,8 +99,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
FailureOr<OpFoldResult> reified;
if (!reifyToFuncArgs) {
// Reify in terms of the op's operands.
- reified = reifyValueBound(rewriter, op->getLoc(),
- presburger::BoundType::EQ, value, dim);
+ reified =
+ reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
} else {
// Reify in terms of function block arguments.
auto stopCondition = [](Value v) {
@@ -88,9 +110,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return isa<FunctionOpInterface>(
bbArg.getParentBlock()->getParentOp());
};
- reified =
- reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ,
- value, dim, stopCondition);
+ reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
+ dim, stopCondition);
}
if (failed(reified)) {
op->emitOpError("could not reify bound");
More information about the Mlir-commits
mailing list