[Mlir-commits] [mlir] [mlir] Add affine.delinearize_index and affine.linearize_index ValueBoundsOpInterfaceImpl (PR #118829)
Quinn Dawkins
llvmlistbot at llvm.org
Fri Dec 6 13:54:55 PST 2024
================
@@ -49,6 +49,67 @@ struct AffineApplyOpInterface
}
};
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  /// Constrains result #k of `affine.delinearize_index %lin into (b_0..b_{n-1})`
+  /// to equal the closed form
+  ///   (lin mod (b_k * ... * b_{n-1})) floordiv (b_{k+1} * ... * b_{n-1})
+  /// where the `mod` is dropped for the outermost result (k == 0) and the
+  /// `floordiv` is dropped for the innermost result (its divisor is 1).
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto delinearizeOp = cast<AffineDelinearizeIndexOp>(op);
+    auto result = cast<OpResult>(value);
+    int64_t resultIdx = result.getResultNumber();
+    int64_t numResults = delinearizeOp->getNumResults();
+    assert(result.getOwner() == delinearizeOp && "invalid value");
+
+    AffineExpr linearIdxExpr = cstr.getExpr(delinearizeOp.getLinearIndex());
+    // NOTE(review): indexing `basis` by result number assumes getMixedBasis()
+    // returns one entry per result (i.e. it includes the outermost bound). If
+    // the outer bound can be omitted, a padded basis is needed here - confirm.
+    SmallVector<OpFoldResult> basis = delinearizeOp.getMixedBasis();
+
+    // modExpr    = b_k * b_{k+1} * ... * b_{n-1}
+    // strideExpr =       b_{k+1} * ... * b_{n-1}
+    AffineExpr modExpr = getAffineConstantExpr(1, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    for (int64_t i = static_cast<int64_t>(basis.size()) - 1; i >= resultIdx;
+         --i) {
+      AffineExpr basisExpr = cstr.getExpr(basis[i]);
+      modExpr = modExpr * basisExpr;
+      if (i > resultIdx)
+        strideExpr = strideExpr * basisExpr;
+    }
+
+    AffineExpr bound = linearIdxExpr;
+    // The outermost result is not reduced modulo anything.
+    if (resultIdx > 0)
+      bound = bound % modExpr;
+    // The innermost result has stride 1, so dividing would be a no-op.
+    if (resultIdx + 1 < numResults)
+      bound = bound.floorDiv(strideExpr);
+
+    cstr.bound(value) == bound;
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto linearizeOp = cast<AffineLinearizeIndexOp>(op);
+ assert(value == linearizeOp.getResult() && "invalid value");
+
+ SmallVector<OpFoldResult> basis = linearizeOp.getMixedBasis();
+ SmallVector<AffineExpr> basisExprs = llvm::map_to_vector(
+ basis, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+ basisExprs.push_back(getAffineConstantExpr(1, op->getContext()));
+
+ SmallVector<OpFoldResult> indices(linearizeOp.getMultiIndex());
+ SmallVector<AffineExpr> indexExprs = llvm::map_to_vector(
+ indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
----------------
qedawkins wrote:
Instead of creating the vector and then reversing it, you can use `llvm::map_to_vector(llvm::reverse(indices), ...)` here and above.
https://github.com/llvm/llvm-project/pull/118829
More information about the Mlir-commits mailing list