[Mlir-commits] [mlir] [mlir] Add affine.delinearize_index and affine.linearize_index ValueBoundsOpInterfaceImpl (PR #118829)

Quinn Dawkins llvmlistbot at llvm.org
Fri Dec 6 13:54:55 PST 2024


================
@@ -49,6 +49,67 @@ struct AffineApplyOpInterface
   }
 };
 
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  /// Adds an equality bound for result #k of `affine.delinearize_index`:
+  ///   result_k == (linear mod (basis[k] * ... * basis[last]))
+  ///                 floordiv (basis[k+1] * ... * basis[last])
+  /// The `mod` is omitted for the first result, and the `floordiv` is omitted
+  /// for the last result (its stride is the constant 1).
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto delinearizeOp = cast<AffineDelinearizeIndexOp>(op);
+    auto result = cast<OpResult>(value);
+    assert(result.getOwner() == delinearizeOp && "invalid value");
+    int64_t resultIdx = result.getResultNumber();
+    int64_t numResults = delinearizeOp->getNumResults();
+
+    AffineExpr linearIdxExpr = cstr.getExpr(delinearizeOp.getLinearIndex());
+    // NOTE(review): assumes `basis` has one entry per result (i.e. it
+    // includes the outer bound) -- confirm against the op definition.
+    SmallVector<OpFoldResult> basis = delinearizeOp.getMixedBasis();
+    // modExpr    = product of basis[resultIdx .. end)
+    // strideExpr = product of basis[resultIdx + 1 .. end)
+    AffineExpr modExpr = getAffineConstantExpr(1, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    for (int64_t i = static_cast<int64_t>(basis.size()) - 1; i >= resultIdx;
+         --i) {
+      AffineExpr basisExpr = cstr.getExpr(basis[i]);
+      modExpr = modExpr * basisExpr;
+      if (i > resultIdx)
+        strideExpr = strideExpr * basisExpr;
+    }
+
+    AffineExpr bound = linearIdxExpr;
+    // The first result is not reduced by a modulus.
+    if (resultIdx > 0)
+      bound = bound % modExpr;
+    // The last result has a unit stride, so the division would be a no-op.
+    if (resultIdx + 1 < numResults)
+      bound = bound.floorDiv(strideExpr);
+
+    cstr.bound(value) == bound;
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto linearizeOp = cast<AffineLinearizeIndexOp>(op);
+    assert(value == linearizeOp.getResult() && "invalid value");
+
+    SmallVector<OpFoldResult> basis = linearizeOp.getMixedBasis();
+    SmallVector<AffineExpr> basisExprs = llvm::map_to_vector(
+        basis, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+    basisExprs.push_back(getAffineConstantExpr(1, op->getContext()));
+
+    SmallVector<OpFoldResult> indices(linearizeOp.getMultiIndex());
+    SmallVector<AffineExpr> indexExprs = llvm::map_to_vector(
+        indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
----------------
qedawkins wrote:

Instead of creating the vector and then reversing it, you can use `llvm::map_to_vector(llvm::reverse(indices), ...)` here and above.

https://github.com/llvm/llvm-project/pull/118829


More information about the Mlir-commits mailing list