[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)

Amir Bishara llvmlistbot at llvm.org
Sun Sep 14 13:20:38 PDT 2025


================
@@ -90,33 +152,105 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
 }
 
 func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
-  assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
-  assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
-  SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
-  SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp, // Builds a new call with remapped operands/results, redirects uses, erases callOp, returns the new call.
+                                  ArrayRef<int> oldArgIdxToNewArgIdx, // Entry i: position of old operand i in the new call.
+                                  ArrayRef<int> oldResIdxToNewResIdx) { // Entry i: position of old result i in the new call.
+  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
+         "oldArgIdxToNewArgIdx must match the number of operands in the call "
+         "operation");
+  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+         "oldResIdxToNewResIdx must match the number of results in the call "
+         "operation"); // NOTE(review): only sizes are asserted; the index values themselves are assumed to be valid positions.
+  // Invert old->new into new->[old...] (several old indices may share one new index, e.g. deduplicated operands), then gather one element per new index.
+  SmallVector<Value> origOperands = callOp.getOperands();
+  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+      getInverseMapping(oldArgIdxToNewArgIdx);
+  SmallVector<Value> newOperandsValues =
+      getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs); // Presumably picks a representative old operand per new index -- TODO confirm against getMappedElements.
+  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+      getInverseMapping(oldResIdxToNewResIdx);
+  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
+  SmallVector<Type> newResultTypes =
+      getMappedElements<Type>(origResultTypes, newResToOldResIdxs); // Result types of the new call, ordered by new result index.

   // Replace the kernel call operation with a new one that has the
-  // reordered arguments.
+  // mapped arguments. NOTE(review): "kernel" above is leftover wording; this rebuilds a plain func.call.
   rewriter.setInsertionPoint(callOp);
   auto newCallOp =
       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
-                           newResultTypes, newArgsOrderValues);
+                           newResultTypes, newOperandsValues); // NOTE(review): only the no-inline attribute is carried over (next line); confirm no other call attributes need preserving.
   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
-  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
-    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
-                                newCallOp.getResult(newIndex));
+  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx)) // Every old result is redirected; old results sharing a new index all end up using that one new result.
+    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+                                newCallOp.getResult(newResIdx));
   rewriter.eraseOp(callOp);

   return newCallOp;
 }
+
+FailureOr<std::pair<func::FuncOp, func::CallOp>>
+func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
+                              ModuleOp moduleOp, std::string &errorMessage) {
+  SmallVector<func::CallOp> callOps;
+  auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
+    if (callOp.getCallee() == funcOp.getSymName()) {
+      if (!callOps.empty())
+        // Only support one callOp for now
+        return WalkResult::interrupt();
+      callOps.push_back(callOp);
+    }
+    return WalkResult::advance();
+  });
+
+  if (traversalResult.wasInterrupted()) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' has more than one callOp";
+    return failure();
+  }
+
+  if (callOps.empty()) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' does not have any callOp";
+    return failure();
+  }
+
+  func::CallOp callOp = callOps.front();
+
+  // Create mapping for arguments (deduplicate operands)
+  SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+  llvm::DenseMap<Value, int> valueToNewArgIdx;
+  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+    auto [iterator, inserted] = valueToNewArgIdx.insert(
+        {operand, static_cast<int>(valueToNewArgIdx.size())});
+    // Reduce the duplicate operands and maintain the original order.
+    oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
+  }
+
+  bool hasDuplicateOperands =
+      valueToNewArgIdx.size() != callOp.getNumOperands();
+  if (!hasDuplicateOperands) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' does not have duplicate operands";
+    return failure();
----------------
amirBish wrote:

Replaced with `LDBG()`.

https://github.com/llvm/llvm-project/pull/158266


More information about the Mlir-commits mailing list