[Mlir-commits] [mlir] [mlir] Add Memref Normalization support for reinterpret_cast op (PR #133417)

Uday Bondhugula llvmlistbot at llvm.org
Fri Apr 25 03:25:05 PDT 2025


================
@@ -1216,53 +1138,55 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   if (usePositions.empty())
     return success();
 
-  if (usePositions.size() > 1) {
-    // TODO: extend it for this case when needed (rare).
-    assert(false && "multiple dereferencing uses in a single op not supported");
-    return failure();
-  }
-
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
-  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    for (unsigned pos : usePositions)
+      op->setOperand(pos, newMemRef);
+    return success();
+  }
 
-    // Check if it is a memref.load
-    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-    bool isReductionLike =
-        indexRemap.getNumResults() < indexRemap.getNumInputs();
-    if (!memrefLoad || !isReductionLike) {
-      op->setOperand(memRefOperandPos, newMemRef);
-      return success();
-    }
+  if (usePositions.size() > 1) {
+    // TODO: extend it for this case when needed (rare).
+    LLVM_DEBUG(llvm::dbgs()
+               << "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
 
-    return transformMemRefLoadWithReducedRank(
-        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-        symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  auto startIdx = op->operand_begin() + memRefOperandPos + 1;
+  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // quering the number of map operands from the operand list from a certain
+    // offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
+  } else {
+    oldMapOperands.assign(startIdx, startIdx + oldMemRefRank);
----------------
bondhugula wrote:

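Set `oldMemRefNumIndices` to `oldMemRefRank` in the `else` branch, and then do the `oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices)` call once, after the `if`/`else`, instead of duplicating it in both branches?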

https://github.com/llvm/llvm-project/pull/133417


More information about the Mlir-commits mailing list