[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 12 03:47:18 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-func
Author: Amir Bishara (amirBish)
<details>
<summary>Changes</summary>
This PR adds a new transform operation that removes duplicate arguments from a function operation based on the operands of the function's call operation: arguments that receive the same value at the call site are merged into a single argument.
To keep the implementation simple for now, the transform fails when the function whose duplicate arguments we want to eliminate has more than one callOp.
This pull request also adapts the utilities under the func dialect so that they can be reused by this transform op.
---
Patch is 31.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158266.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td (+26)
- (modified) mlir/include/mlir/Dialect/Func/Utils/Utils.h (+19-12)
- (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+95-6)
- (modified) mlir/lib/Dialect/Func/Utils/Utils.cpp (+121-38)
- (modified) mlir/test/Dialect/Func/func-transform-invalid.mlir (+108)
- (modified) mlir/test/Dialect/Func/func-transform.mlir (+72)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 4062f310c6521..b64b3fcdb275b 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
}];
}
+def DeduplicateFuncArgsOp
+ : Op<Transform_Dialect, "func.deduplicate_func_args",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ This transform takes a module and a function name, and deduplicates
+ the arguments of the function. The function is expected to be defined in
+ the module and to have exactly one call site in the module.
+
+ Deduplication is driven by that call site: arguments that receive the
+ same SSA value at the call are merged into a single argument. The
+ remaining arguments keep their original relative order, the results are
+ left unchanged, and the call is rewritten to match the new signature.
+
+ The `module` handle is consumed. The transform returns a handle to the
+ module and a handle to the newly created function.
+
+ This transform will emit a silenceable failure if:
+ - The `module` handle does not point to a module operation.
+ - The function with the given name does not exist in the module.
+ - The function does not have a single call.
+ - The call does not pass the same value to more than one argument.
+ - The function body does not consist of exactly one block.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$module,
+ SymbolRefAttr:$function_name);
+ let results = (outs TransformHandleTypeInterface:$transformed_module,
+ TransformHandleTypeInterface:$transformed_function);
+
+ let assemblyFormat = [{
+ $function_name
+ `at` $module attr-dict `:` functional-type(operands, results)
+ }];
+}
+
#endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 2e8b6723a0e53..464ebc1305d60 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -17,7 +17,7 @@
#define MLIR_DIALECT_FUNC_UTILS_H
#include "mlir/IR/PatternMatch.h"
-#include "llvm/ADT/ArrayRef.h"
+#include <map>
namespace mlir {
@@ -27,21 +27,28 @@ class FuncOp;
class CallOp;
/// Creates a new function operation with the same name as the original
-/// function operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
+/// function operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
/// The `funcOp` operation must have exactly one block.
/// Returns the new function operation or failure if `funcOp` doesn't
/// have exactly one block.
-FailureOr<FuncOp>
-replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
- llvm::ArrayRef<unsigned> newArgsOrder,
- llvm::ArrayRef<unsigned> newResultsOrder);
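+/// Note: the method asserts that `oldArgToNewArg` and `oldResToNewRes`
+/// cover all of the function's arguments and results. Several old indices
+/// may map to the same new index, in which case they must have the same
+/// type; the new indices are expected to form a contiguous range from 0.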
+mlir::FailureOr<mlir::func::FuncOp>
+replaceFuncWithNewMapping(mlir::RewriterBase &rewriter,
+ mlir::func::FuncOp funcOp,
+ const std::map<unsigned, unsigned> &oldArgToNewArg,
+ const std::map<unsigned, unsigned> &oldResToNewRes);
/// Creates a new call operation with the values as the original
-/// call operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
-CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
- llvm::ArrayRef<unsigned> newArgsOrder,
- llvm::ArrayRef<unsigned> newResultsOrder);
+/// call operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
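+/// Note: the method asserts that `oldArgToNewArg` and `oldResToNewRes`
+/// cover all of the call operation's operands and results. Several old
+/// indices may map to the same new index, in which case they must have the
+/// same type.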
+mlir::func::CallOp
+replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter,
+ mlir::func::CallOp callOp,
+ const std::map<unsigned, unsigned> &oldArgToNewArg,
+ const std::map<unsigned, unsigned> &oldResToNewRes);
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 935d3e5ac331b..486dad09f9392 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
using namespace mlir;
@@ -296,9 +297,15 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
}
}
- FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
- rewriter, funcOp, argsInterchange.getArrayRef(),
- resultsInterchange.getArrayRef());
+ std::map<unsigned, unsigned> oldArgToNewArg, oldResToNewRes;
+ for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(argsInterchange))
+ oldArgToNewArg[oldArgIdx] = newArgIdx;
+
+ for (auto [oldResIdx, newResIdx] : llvm::enumerate(resultsInterchange))
+ oldResToNewRes[oldResIdx] = newResIdx;
+
+ FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+ rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
if (failed(newFuncOpOrFailure))
return emitSilenceableFailure(getLoc())
<< "failed to replace function signature '" << getFunctionName()
@@ -312,9 +319,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
});
for (func::CallOp callOp : callOps)
- func::replaceCallOpWithNewOrder(rewriter, callOp,
- argsInterchange.getArrayRef(),
- resultsInterchange.getArrayRef());
+ func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
+ oldResToNewRes);
}
results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +336,89 @@ void transform::ReplaceFuncSignatureOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+/// Merges the arguments of `function_name` that receive the same SSA value at
+/// its single call site, then rewrites the function and that call.
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payloadOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(payloadOps))
+ return emitDefiniteFailure() << "requires a single module to operate on";
+
+ auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+ if (!targetModuleOp)
+ return emitSilenceableFailure(getLoc())
+ << "target is expected to be module operation";
+
+ func::FuncOp funcOp =
+ targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+ if (!funcOp)
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' is not found";
+
+ // NOTE(review): only the root reference of `function_name` is compared
+ // against the callee, so a nested symbol reference is not matched
+ // precisely; confirm this is intended.
+ SmallVector<func::CallOp> callOps;
+ targetModuleOp.walk([&](func::CallOp callOp) {
+ if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
+ callOps.push_back(callOp);
+ });
+
+ // TODO: Support more than one callOp.
+ if (!llvm::hasSingleElement(callOps))
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName()
+ << "' does not have a single callOp";
+
+ // Give each distinct call operand a new argument index in order of first
+ // occurrence. The surviving arguments keep their original relative order,
+ // and every repeated operand is folded onto its first occurrence.
+ func::CallOp callOp = callOps.front();
+ std::map<unsigned, unsigned> oldArgIdxToNewArgIdx;
+ llvm::DenseMap<Value, unsigned> valueToNewArgIdx;
+ for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+ // A single lookup: the current size becomes the new index only when
+ // `operand` has not been seen before. The argument is evaluated before the
+ // insertion, so the first distinct operand always gets index 0.
+ auto it =
+ valueToNewArgIdx.try_emplace(operand, valueToNewArgIdx.size()).first;
+ oldArgIdxToNewArgIdx[operandIdx] = it->second;
+ }
+
+ // If every operand received its own index, there is nothing to merge.
+ if (valueToNewArgIdx.size() == callOp.getNumOperands())
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName()
+ << "' does not have duplicate operands";
+
+ // Results are left untouched: identity mapping.
+ std::map<unsigned, unsigned> oldResIdxToNewResIdx;
+ for (unsigned resultIdx = 0; resultIdx < callOp.getNumResults(); ++resultIdx)
+ oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+ // NOTE(review): when merged arguments carry different argument attributes,
+ // replaceFuncWithNewMapping keeps only those of the first occurrence.
+ FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+ rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+ if (failed(newFuncOpOrFailure))
+ return emitSilenceableFailure(getLoc())
+ << "failed to deduplicate function arguments '" << getFunctionName()
+ << "'";
+
+ func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgIdxToNewArgIdx,
+ oldResIdxToNewResIdx);
+
+ results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+ results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the op's effects: the `module` operand handle is consumed, all
+/// result handles are produced, and the payload IR is modified.
+void transform::DeduplicateFuncArgsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getModuleMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index f781ed2d591b4..a58eb7233f460 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -18,16 +18,16 @@
using namespace mlir;
-FailureOr<func::FuncOp>
-func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
- ArrayRef<unsigned> newArgsOrder,
- ArrayRef<unsigned> newResultsOrder) {
+FailureOr<func::FuncOp> func::replaceFuncWithNewMapping(
+ RewriterBase &rewriter, func::FuncOp funcOp,
+ const std::map<unsigned, unsigned> &oldArgToNewArg,
+ const std::map<unsigned, unsigned> &oldResToNewRes) {
// Generate an empty new function operation with the same name as the
// original.
- assert(funcOp.getNumArguments() == newArgsOrder.size() &&
- "newArgsOrder must match the number of arguments in the function");
- assert(funcOp.getNumResults() == newResultsOrder.size() &&
- "newResultsOrder must match the number of results in the function");
+ assert(funcOp.getNumArguments() == oldArgToNewArg.size() &&
+ "oldArgToNewArg must match the number of arguments in the function");
+ assert(funcOp.getNumResults() == oldResToNewRes.size() &&
+ "oldResToNewRes must match the number of results in the function");
if (!funcOp.getBody().hasOneBlock())
return rewriter.notifyMatchFailure(
@@ -37,12 +37,49 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
SmallVector<Type> newInputTypes, newOutputTypes;
SmallVector<Location> locs;
- for (unsigned int idx : newArgsOrder) {
- newInputTypes.push_back(origInputTypes[idx]);
- locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
+
+ std::map<unsigned, SmallVector<unsigned>> newArgToOldArg;
+ for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+ newArgToOldArg[newArgIdx].push_back(oldArgIdx);
+
+ for (auto [newArgIdx, oldArgIdx] : newArgToOldArg) {
+ std::ignore = newArgIdx;
+ assert(llvm::all_of(oldArgIdx,
+ [&funcOp](unsigned idx) -> bool {
+ return idx < funcOp.getNumArguments();
+ }) &&
+ "idx must be less than the number of arguments in the function");
+ assert(!oldArgIdx.empty() && "oldArgIdx must not be empty");
+ Type origInputTypeToCheck = origInputTypes[oldArgIdx.front()];
+ assert(llvm::all_of(oldArgIdx,
+ [&](unsigned idx) -> bool {
+ return origInputTypes[idx] == origInputTypeToCheck;
+ }) &&
+ "all oldArgIdx must have the same type");
+ newInputTypes.push_back(origInputTypeToCheck);
+ locs.push_back(funcOp.getArgument(oldArgIdx.front()).getLoc());
+ }
+
+ std::map<unsigned, SmallVector<unsigned>> newResToOldRes;
+ for (auto [oldResIdx, newResIdx] : oldResToNewRes)
+ newResToOldRes[newResIdx].push_back(oldResIdx);
+
+ for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+ std::ignore = newResIdx;
+ assert(llvm::all_of(oldResIdx,
+ [&funcOp](unsigned idx) -> bool {
+ return idx < funcOp.getNumResults();
+ }) &&
+ "idx must be less than the number of results in the function");
+ Type origOutputTypeToCheck = origOutputTypes[oldResIdx.front()];
+ assert(llvm::all_of(oldResIdx,
+ [&](unsigned idx) -> bool {
+ return origOutputTypes[idx] == origOutputTypeToCheck;
+ }) &&
+ "all oldResIdx must have the same type");
+ newOutputTypes.push_back(origOutputTypeToCheck);
}
- for (unsigned int idx : newResultsOrder)
- newOutputTypes.push_back(origOutputTypes[idx]);
+
rewriter.setInsertionPoint(funcOp);
auto newFuncOp = func::FuncOp::create(
rewriter, funcOp.getLoc(), funcOp.getName(),
@@ -58,14 +95,15 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
IRMapping operandMapper;
SmallVector<DictionaryAttr> argAttrs, resultAttrs;
funcOp.getAllArgAttrs(argAttrs);
- for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
- operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
- newFuncOp.getArgument(i));
- newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
- }
+ for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+ operandMapper.map(funcOp.getArgument(oldArgIdx),
+ newFuncOp.getArgument(newArgIdx));
+ for (auto [newArgIdx, oldArgIdx] : newArgToOldArg)
+ newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
+
funcOp.getAllResultAttrs(resultAttrs);
- for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
- newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+ for (auto [newResIdx, oldResIdx] : newResToOldRes)
+ newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
// Clone the operations from the original function to the new function.
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
@@ -76,8 +114,10 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
auto returnOp = cast<func::ReturnOp>(
newFuncOp.getFunctionBody().begin()->getTerminator());
SmallVector<Value> newReturnValues;
- for (unsigned int idx : newResultsOrder)
- newReturnValues.push_back(returnOp.getOperand(idx));
+ for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+ std::ignore = newResIdx;
+ newReturnValues.push_back(returnOp.getOperand(oldResIdx.front()));
+ }
rewriter.setInsertionPoint(returnOp);
auto newReturnOp =
func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
@@ -89,33 +129,76 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
return newFuncOp;
}
-func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
- ArrayRef<unsigned> newArgsOrder,
- ArrayRef<unsigned> newResultsOrder) {
+func::CallOp func::replaceCallOpWithNewMapping(
+ RewriterBase &rewriter, func::CallOp callOp,
+ const std::map<unsigned, unsigned> &oldArgToNewArg,
+ const std::map<unsigned, unsigned> &oldResToNewRes) {
assert(
- callOp.getNumOperands() == newArgsOrder.size() &&
- "newArgsOrder must match the number of operands in the call operation");
+ callOp.getNumOperands() == oldArgToNewArg.size() &&
+ "oldArgToNewArg must match the number of operands in the call operation");
assert(
- callOp.getNumResults() == newResultsOrder.size() &&
- "newResultsOrder must match the number of results in the call operation");
+ callOp.getNumResults() == oldResToNewRes.size() &&
+ "oldResToNewRes must match the number of results in the call operation");
+
+ // Inverse mapping from new arguments to old arguments.
+ std::map<unsigned, SmallVector<unsigned>> newArgToOldArg;
+ for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+ newArgToOldArg[newArgIdx].push_back(oldArgIdx);
+
SmallVector<Value> newArgsOrderValues;
- for (unsigned int argIdx : newArgsOrder)
- newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+ for (const auto &[newArgIdx, oldArgIdx] : newArgToOldArg) {
+ std::ignore = newArgIdx;
+ assert(
+ llvm::all_of(oldArgIdx,
+ [&callOp](unsigned idx) -> bool {
+ return idx < callOp.getNumOperands();
+ }) &&
+ "idx must be less than the number of operands in the call operation");
+ assert(!oldArgIdx.empty() && "oldArgIdx must not be empty");
+ Value origOperandToCheck = callOp.getOperand(oldArgIdx.front());
+ assert(llvm::all_of(oldArgIdx,
+ [&](unsigned idx) -> bool {
+ return callOp.getOperand(idx).getType() ==
+ origOperandToCheck.getType();
+ }) &&
+ "all oldArgIdx must have the same type");
+ newArgsOrderValues.push_back(origOperandToCheck);
+ }
+
SmallVector<Type> newResultTypes;
- for (unsigned int resIdx : newResultsOrder)
- newResultTypes.push_back(callOp.getResult(resIdx).getType());
+ std::map<unsigned, SmallVector<unsigned>> newResToOldRes;
+ for (auto [oldResIdx, newResIdx] : oldResToNewRes)
+ newResToOldRes[newResIdx].push_back(oldResIdx);
+
+ for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+ std::ignore = newResIdx;
+ assert(llvm::all_of(oldResIdx,
+ [&callOp](unsigned idx) -> bool {
+ return idx < callOp.getNumResults();
+ }) &&
+ "idx must be less than the number of results in the call operation");
+ assert(!oldResIdx.empty() && "oldResIdx must not be empty");
+ Value origResultToCheck = callOp.getResult(oldResIdx.front());
+ assert(llvm::all_of(oldResIdx,
+ [&](unsigned idx) -> bool {
+ return callOp.getResult(idx).getType() ==
+ origResultToCheck.getType();
+ }) &&
+ "all oldResIdx must have the same type");
+ newResultTypes.push_back(origResultToCheck.getType());
+ }
// Replace the kernel call operation with a new one that has the
- // reordered arguments.
+ // mapped arguments.
rewriter.setInsertionPoint(callOp);
auto newCallOp =
func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
newResultTypes, newArgsOrderValues);
+ newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
- for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
- rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
- newCallOp.getResult(newIndex));
+ for (auto &&[oldResIdx, newResIdx] : oldResToNewRes)
+ rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+ newCallOp.getResult(newResIdx));
rewriter.eraseOp(callOp);
return newCallOp;
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
index e712eee83f36e..d260a36a723b6 100644
--- a/mlir/test/Dialect/Func/func-transform-invalid.mlir
+++ b/mlir/test/Dialect/Func/func-transform-i...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/158266
More information about the Mlir-commits
mailing list