[Mlir-commits] [mlir] [mlir] Add Memref Normalization support for reinterpret_cast op (PR #133417)

Uday Bondhugula llvmlistbot at llvm.org
Fri Apr 25 03:25:05 PDT 2025


================
@@ -1845,6 +1781,95 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp.getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp.getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  newMemRefType.dump();
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp.getStrides();
+  Location loc = reinterpretCastOp.getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  SmallVector<int64_t> newStaticOffsets(newRank, 0);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp.getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(reinterpretCastOp);
+  // Collectthe map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (memrefType.isDynamicDim(i))
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (!newMemRefType.isDynamicDim(i))
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
----------------
bondhugula wrote:

`auto`

https://github.com/llvm/llvm-project/pull/133417


More information about the Mlir-commits mailing list