[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)

Kai Sasaki llvmlistbot at llvm.org
Tue Sep 10 00:02:00 PDT 2024


================
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
       // is set.
       return failure();
     }
-    op->setOperand(memRefOperandPos, newMemRef);
+
+    // Check if it is a memref.load
+    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+    bool isReductionLike =
+        indexRemap.getNumResults() < indexRemap.getNumInputs();
+    if (!memrefLoad || !isReductionLike) {
+      op->setOperand(memRefOperandPos, newMemRef);
+      return success();
+    }
+
+    unsigned oldMapNumInputs = oldMemRefRank;
+    SmallVector<Value, 4> oldMapOperands(
+        op->operand_begin() + memRefOperandPos + 1,
+        op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+    SmallVector<Value, 4> oldMemRefOperands;
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+    SmallVector<Value, 4> remapOperands;
+    remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                          symbolOperands.size());
+    remapOperands.append(extraOperands.begin(), extraOperands.end());
+    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+    remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+    SmallVector<Value, 4> remapOutputs;
+    remapOutputs.reserve(oldMemRefRank);
+    SmallVector<Value, 4> affineApplyOps;
+
+    if (indexRemap &&
+        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+      // Remapped indices.
+      for (auto resultExpr : indexRemap.getResults()) {
+        auto singleResMap = AffineMap::get(
+            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+        auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                  remapOperands);
+        remapOutputs.push_back(afOp);
+        affineApplyOps.push_back(afOp);
+      }
+    } else {
+      // No remapping specified.
+      remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+    }
+
+    SmallVector<Value, 4> newMapOperands;
+    newMapOperands.reserve(newMemRefRank);
+
+    // Prepend 'extraIndices' in 'newMapOperands'.
+    for (Value extraIndex : extraIndices) {
+      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+             "invalid memory op index");
+      newMapOperands.push_back(extraIndex);
+    }
+
+    // Append 'remapOutputs' to 'newMapOperands'.
+    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+    // Create new fully composed AffineMap for new op to be created.
+    assert(newMapOperands.size() == newMemRefRank);
+
+    OperationState state(op->getLoc(), op->getName());
+    // Construct the new operation using this memref.
+    state.operands.reserve(newMapOperands.size() + extraIndices.size());
+    state.operands.push_back(newMemRef);
+
+    // Insert the new memref map operands.
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+    state.types.reserve(op->getNumResults());
+    for (auto result : op->getResults())
+      state.types.push_back(result.getType());
+
+    // Add attribute for 'newMap', other Attributes do not change.
+    // auto newMapAttr = AffineMapAttr::get(newMap);
----------------
Lewuathe wrote:

Is this line necessary?

https://github.com/llvm/llvm-project/pull/107771


More information about the Mlir-commits mailing list