[Mlir-commits] [mlir] [mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (PR #121833)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 6 13:52:08 PST 2025


================
@@ -91,6 +91,66 @@ struct AffineMaxOpInterface
   };
 };
 
+/// ValueBoundsOpInterface external model for `affine.delinearize_index`.
+///
+/// With (padded) basis b_0..b_{n-1} and d_i = b_{i+1} * ... * b_{n-1}
+/// (d_{n-1} = 1), result i is bounded as:
+///   res_0 == linearIdx floordiv d_0   (and res_0 < b_0 when b_0 is present)
+///   res_i == (linearIdx mod (b_i * d_i)) floordiv d_i   for i > 0
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *rawOp, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto op = cast<AffineDelinearizeIndexOp>(rawOp);
+    auto result = cast<OpResult>(value);
+    assert(result.getOwner() == rawOp &&
+           "bounded value isn't a result of this delinearize_index");
+    unsigned resIdx = result.getResultNumber();
+
+    AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
+
+    // The divisor for result `resIdx` is the product of every basis element
+    // after it; for the innermost result the product is empty, i.e. 1.
+    SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+    AffineExpr divisor = cstr.getExpr(1);
+    for (OpFoldResult basisElem :
+         ArrayRef<OpFoldResult>(basis).drop_front(resIdx + 1))
+      divisor = divisor * cstr.getExpr(basisElem);
+
+    // `cstr.bound(value) <op> expr` is evaluated only for its side effect:
+    // the comparison operators on the builder returned by `bound()` record a
+    // constraint in `cstr`. Calling `cstr.bound(value)` at each use (rather
+    // than holding the builder in an `auto` local) keeps that visible.
+    if (resIdx == 0) {
+      cstr.bound(value) == linearIdx.floorDiv(divisor);
+      // A null outermost basis element means no outer bound was given, so
+      // the outermost result gets no upper bound in that case.
+      if (!basis.front().isNull())
+        cstr.bound(value) < cstr.getExpr(basis.front());
+      return;
+    }
+    AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
+    cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *rawOp, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto op = cast<AffineLinearizeIndexOp>(rawOp);
+    assert(value == op.getResult() &&
+           "value isn't the result of this linearize");
+
+    AffineExpr bound = cstr.getExpr(0);
+    AffineExpr stride = cstr.getExpr(1);
+    SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+    OperandRange multiIndex = op.getMultiIndex();
+    for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
+      unsigned argNum = multiIndex.size() - (revArgNum + 1);
+      if (argNum == 0)
+        break;
+      OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
+      bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
+      stride = stride * cstr.getExpr(length);
+    }
+    bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
+    auto resBound = cstr.bound(value);
----------------
MaheshRavishankar wrote:

Same here. The `auto` is confusing.

https://github.com/llvm/llvm-project/pull/121833


More information about the Mlir-commits mailing list