[Mlir-commits] [mlir] [mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (PR #121833)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Jan 6 13:25:52 PST 2025
================
@@ -91,6 +91,66 @@ struct AffineMaxOpInterface
};
};
+// External model implementing ValueBoundsOpInterface for
+// AffineDelinearizeIndexOp: expresses each result of the op as an affine
+// function of the op's linear index and its (padded) basis.
+struct AffineDelinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+ // Adds constraints on `value` to `cstr`.
+ //
+ // `rawOp` must be an AffineDelinearizeIndexOp and `value` must be one of
+ // its results (checked by the cast and the assertion below).
+ // With `divisor` = product of the padded-basis entries that come after the
+ // result's position, the constraints added are:
+ //   result 0:  value == linearIdx floordiv divisor, plus
+ //              value < basis[0] when that entry is non-null;
+ //   result i:  value == (linearIdx mod (basis[i] * divisor)) floordiv divisor.
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineDelinearizeIndexOp>(rawOp);
+ auto result = cast<OpResult>(value);
+ assert(result.getOwner() == rawOp &&
+ "bounded value isn't a result of this delinearize_index");
+ unsigned resIdx = result.getResultNumber();
+
+ AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
+
+ // NOTE(review): getPaddedBasis() presumably returns one entry per result,
+ // with a null leading entry when the op has no outer bound - confirm
+ // against the op definition; the isNull() check below relies on this.
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ // Multiply together every basis entry strictly after position `resIdx`.
+ AffineExpr divisor = cstr.getExpr(1);
+ for (OpFoldResult basisElem :
+ ArrayRef<OpFoldResult>(basis).drop_front(resIdx + 1))
+ divisor = divisor * cstr.getExpr(basisElem);
+
+ // The comparison operators applied to `resBound` below are used as
+ // statements: their results are discarded, so they act through side
+ // effects (presumably recording the constraint in `cstr` - see the
+ // ValueBoundsConstraintSet API).
+ auto resBound = cstr.bound(result);
+ if (resIdx == 0) {
+ // Leading result: no modulo is applied, only the floor division.
+ resBound == linearIdx.floorDiv(divisor);
+ // An upper bound is only added when the leading basis entry is present.
+ if (!basis.front().isNull())
+ resBound < cstr.getExpr(basis.front());
+ return;
+ }
+ // Remaining results: reduce modulo this position's extent times the
+ // trailing product, then divide by the trailing product.
+ AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
+ resBound == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
+ }
+};
+
+struct AffineLinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineLinearizeIndexOp>(rawOp);
+ assert(value == op.getResult() &&
+ "value isn't the result of this linearize");
+
+ AffineExpr bound = cstr.getExpr(0);
+ AffineExpr stride = cstr.getExpr(1);
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ OperandRange multiIndex = op.getMultiIndex();
+ for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
+ unsigned argNum = multiIndex.size() - (revArgNum + 1);
+ if (argNum == 0)
+ break;
+ OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
+ bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
+ stride = stride * cstr.getExpr(length);
----------------
krzysz00 wrote:
We need the right-to-left partial products
https://github.com/llvm/llvm-project/pull/121833
More information about the Mlir-commits mailing list