[Mlir-commits] [mlir] [MLIR][Affine] Add vector support to affine.linearize_index and affine.delinearize_index (PR #188369)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Mar 26 12:50:02 PDT 2026
================
@@ -29,19 +31,85 @@ using namespace mlir;
using namespace mlir::affine;
namespace {
+
+/// Create a constant splat of the given type with the given integer value.
+static Value createTypedConstant(OpBuilder &b, Location loc, Type type,
+ int64_t value) {
+ if (auto vecTy = dyn_cast<VectorType>(type))
+ return arith::ConstantOp::create(
+ b, loc, DenseElementsAttr::get(vecTy, b.getIndexAttr(value)));
+ return arith::ConstantIndexOp::create(b, loc, value);
+}
+
+/// Materialize an OpFoldResult (which represents a scalar index or constant)
+/// as a Value matching the given target type. For vector target types, scalar
+/// constants are splatted. Returns failure for dynamic basis with vector types
+/// since that requires vector.broadcast which is not available here.
+static FailureOr<Value> materializeBasis(OpBuilder &b, Location loc,
+ OpFoldResult ofr, Type targetType) {
+ std::optional<int64_t> cst = getConstantIntValue(ofr);
+ if (cst)
+ return createTypedConstant(b, loc, targetType, *cst);
+ // Dynamic scalar basis value. For scalar target types, return as-is.
+ if (isa<IndexType>(targetType))
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ // Dynamic scalar basis with vector target type -- would need
+ // vector.broadcast, bail out.
+ return failure();
+}
+
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
struct LowerDelinearizeIndexOps
: public OpRewritePattern<AffineDelinearizeIndexOp> {
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- FailureOr<SmallVector<Value>> multiIndex =
- delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
- op.getEffectiveBasis(), /*hasOuterBound=*/false);
- if (failed(multiIndex))
- return failure();
- rewriter.replaceOp(op, *multiIndex);
+ // For scalar types, use the existing affine lowering path.
+ if (isa<IndexType>(op.getLinearIndex().getType())) {
+ FailureOr<SmallVector<Value>> multiIndex =
+ delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+ op.getEffectiveBasis(), /*hasOuterBound=*/false);
+ if (failed(multiIndex))
+ return failure();
+ rewriter.replaceOp(op, *multiIndex);
+ return success();
+ }
+
+ // Vector lowering: emit arith div/rem ops (which work element-wise on
+ // vectors).
+ Location loc = op.getLoc();
+ Value linearIndex = op.getLinearIndex();
+ Type type = linearIndex.getType();
+ SmallVector<OpFoldResult> basis = op.getEffectiveBasis();
+
+ // Compute cumulative products of basis from the right. These serve as
+ // divisors: for basis (B0, B1, B2), the divisors are (B1*B2, B2).
+ SmallVector<Value> divisors;
+ Value cumulativeProd = createTypedConstant(rewriter, loc, type, 1);
+ for (OpFoldResult basisElem : llvm::reverse(basis)) {
+ FailureOr<Value> basisVal =
+ materializeBasis(rewriter, loc, basisElem, type);
+ if (failed(basisVal))
+ return failure();
+ cumulativeProd =
+ arith::MulIOp::create(rewriter, loc, cumulativeProd, *basisVal);
+ divisors.push_back(cumulativeProd);
+ }
+
+ // Emit div/mod pairs from the most-significant dimension to the least.
+ SmallVector<Value> results;
+ results.reserve(divisors.size() + 1);
+ Value residual = linearIndex;
+ for (Value divisor : llvm::reverse(divisors)) {
----------------
krzysz00 wrote:
..., no that's the other pass
https://github.com/llvm/llvm-project/pull/188369
More information about the Mlir-commits
mailing list