[Mlir-commits] [mlir] [MLIR][Affine] Add vector support to affine.linearize_index and affine.delinearize_index (PR #188369)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Mar 26 12:50:02 PDT 2026


================
@@ -29,19 +31,85 @@ using namespace mlir;
 using namespace mlir::affine;
 
 namespace {
+
+/// Create an `arith.constant` of type `type` holding the integer `value`.
+/// If `type` is a vector, the constant is a splat of `value` over the
+/// vector's element type; otherwise `type` itself is the scalar result type.
+/// The scalar/element type must be integer-like (`index` or an integer type)
+/// so that `value` can be expressed as an `IntegerAttr` of that type.
+static Value createTypedConstant(OpBuilder &b, Location loc, Type type,
+                                 int64_t value) {
+  auto vecTy = dyn_cast<VectorType>(type);
+  // Derive the attribute type from `type` instead of hard-coding `index`, so
+  // the created constant always has exactly the requested result type.
+  Type elementType = vecTy ? vecTy.getElementType() : type;
+  IntegerAttr scalarAttr = b.getIntegerAttr(elementType, value);
+  if (vecTy)
+    return arith::ConstantOp::create(
+        b, loc, DenseElementsAttr::get(vecTy, scalarAttr));
+  return arith::ConstantOp::create(b, loc, scalarAttr);
+}
+
+/// Turn one basis entry (`ofr`) into a Value of type `targetType`.
+///
+/// An entry that folds to an integer constant is always materialized through
+/// createTypedConstant (splatted when `targetType` is a vector). A
+/// non-constant entry is only accepted when `targetType` is `index`; it is
+/// then forwarded through getValueOrCreateConstantIndexOp. Every other
+/// combination returns failure, because placing a dynamic scalar into a
+/// vector would require `vector.broadcast`. No IR is created on that path.
+static FailureOr<Value> materializeBasis(OpBuilder &b, Location loc,
+                                         OpFoldResult ofr, Type targetType) {
+  if (std::optional<int64_t> constVal = getConstantIntValue(ofr))
+    return createTypedConstant(b, loc, targetType, *constVal);
+
+  // Dynamic entry: only a scalar `index` target can use it unchanged.
+  if (!isa<IndexType>(targetType))
+    return failure();
+  return getValueOrCreateConstantIndexOp(b, loc, ofr);
+}
+
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
 struct LowerDelinearizeIndexOps
     : public OpRewritePattern<AffineDelinearizeIndexOp> {
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
-    if (failed(multiIndex))
-      return failure();
-    rewriter.replaceOp(op, *multiIndex);
+    // For scalar types, use the existing affine lowering path.
+    if (isa<IndexType>(op.getLinearIndex().getType())) {
+      FailureOr<SmallVector<Value>> multiIndex =
+          delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                           op.getEffectiveBasis(), /*hasOuterBound=*/false);
+      if (failed(multiIndex))
+        return failure();
+      rewriter.replaceOp(op, *multiIndex);
+      return success();
+    }
+
+    // Vector lowering: emit arith div/rem ops (which work element-wise on
+    // vectors).
+    Location loc = op.getLoc();
+    Value linearIndex = op.getLinearIndex();
+    Type type = linearIndex.getType();
+    SmallVector<OpFoldResult> basis = op.getEffectiveBasis();
+
+    // Compute cumulative products of basis from the right. These serve as
+    // divisors: for basis (B0, B1, B2), the divisors are (B1*B2, B2).
+    SmallVector<Value> divisors;
+    Value cumulativeProd = createTypedConstant(rewriter, loc, type, 1);
+    for (OpFoldResult basisElem : llvm::reverse(basis)) {
+      FailureOr<Value> basisVal =
+          materializeBasis(rewriter, loc, basisElem, type);
+      if (failed(basisVal))
+        return failure();
+      cumulativeProd =
+          arith::MulIOp::create(rewriter, loc, cumulativeProd, *basisVal);
+      divisors.push_back(cumulativeProd);
+    }
+
+    // Emit div/mod pairs from the most-significant dimension to the least.
+    SmallVector<Value> results;
+    results.reserve(divisors.size() + 1);
+    Value residual = linearIndex;
+    for (Value divisor : llvm::reverse(divisors)) {
----------------
krzysz00 wrote:

..., no that's the other pass

https://github.com/llvm/llvm-project/pull/188369


More information about the Mlir-commits mailing list