[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)
Kai Sasaki
llvmlistbot at llvm.org
Tue Sep 10 00:02:00 PDT 2024
================
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// is set.
return failure();
}
- op->setOperand(memRefOperandPos, newMemRef);
+
+ // Check if it is a memref.load
+ auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+ bool isReductionLike =
+ indexRemap.getNumResults() < indexRemap.getNumInputs();
+ if (!memrefLoad || !isReductionLike) {
+ op->setOperand(memRefOperandPos, newMemRef);
+ return success();
+ }
+
+ unsigned oldMapNumInputs = oldMemRefRank;
+ SmallVector<Value, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+ SmallVector<Value, 4> oldMemRefOperands;
+ oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+ SmallVector<Value, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+ SmallVector<Value, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+ SmallVector<Value, 4> affineApplyOps;
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Remapped indices.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = AffineMap::get(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // No remapping specified.
+ remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // Prepend 'extraIndices' in 'newMapOperands'.
+ for (Value extraIndex : extraIndices) {
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Append 'remapOutputs' to 'newMapOperands'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // Create new fully composed AffineMap for new op to be created.
+ assert(newMapOperands.size() == newMemRefRank);
+
+ OperationState state(op->getLoc(), op->getName());
+ // Construct the new operation using this memref.
+ state.operands.reserve(newMapOperands.size() + extraIndices.size());
+ state.operands.push_back(newMemRef);
+
+ // Insert the new memref map operands.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ state.types.reserve(op->getNumResults());
+ for (auto result : op->getResults())
+ state.types.push_back(result.getType());
+
+ // Add attribute for 'newMap', other Attributes do not change.
+ // auto newMapAttr = AffineMapAttr::get(newMap);
----------------
Lewuathe wrote:
Is this line necessary?
https://github.com/llvm/llvm-project/pull/107771
More information about the Mlir-commits mailing list