[Mlir-commits] [mlir] Normalize reinterpret_cast op (PR #133417)
Uday Bondhugula
llvmlistbot at llvm.org
Fri Mar 28 06:46:28 PDT 2025
================
@@ -1846,6 +1785,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
return success();
}
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+ MemRefType memrefType = reinterpretCastOp->getType();
+ AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+ Value oldMemRef = reinterpretCastOp->getResult();
+
+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+ if (oldLayoutMap.isIdentity())
+ return success();
+
+ // Fetch a new memref type after normalizing the old memref to have an
+ // identity map layout.
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
+ newMemRefType.dump();
+ if (newMemRefType == memrefType)
+ // `oldLayoutMap` couldn't be transformed to an identity map.
+ return failure();
+
+ uint64_t newRank = newMemRefType.getRank();
+ SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+ oldLayoutMap.getNumSymbols());
+ SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+ Location loc = reinterpretCastOp->getLoc();
+ // As `newMemRefType` is normalized, it is unit strided.
+ SmallVector<int64_t> newStaticStrides(newRank, 1);
+ SmallVector<int64_t> newStaticOffsets(newRank, 0);
+ ArrayRef<int64_t> oldShape = memrefType.getShape();
+ mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+ unsigned idx = 0;
+ SmallVector<int64_t> newStaticSizes;
+ OpBuilder b(*reinterpretCastOp);
+ // Collectthe map operands which will be used to compute the new normalized
+ // memref shape.
+ for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+ if (oldShape[i] == ShapedType::kDynamic)
+ mapOperands[i] =
+ b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+ b.create<arith::ConstantIndexOp>(loc, 1));
+ else
+ mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+ }
+ for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+ mapOperands[memrefType.getRank() + i] = oldStrides[i];
+ SmallVector<Value> newSizes;
+ ArrayRef<int64_t> newShape = newMemRefType.getShape();
+ // Compute size along all the dimensions of the new normalized memref.
+ for (unsigned i = 0; i < newRank; i++) {
+ if (newShape[i] != ShapedType::kDynamic)
----------------
bondhugula wrote:
Use `isDynamicDim`.
https://github.com/llvm/llvm-project/pull/133417
More information about the Mlir-commits
mailing list