[Mlir-commits] [mlir] [mlir] Add option to cloning for different results (PR #184202)
llvmlistbot at llvm.org
Mon Mar 2 10:47:21 PST 2026
https://github.com/tdegioanni-nvidia created https://github.com/llvm/llvm-project/pull/184202
With his permission while he is away, I am resurrecting @zero9178's very first MLIR PR #65171, which adds an option to change the results in the operation cloning mechanism. This is very useful because cloning is currently the only way to change the results of an operation.
> Since Operations cannot change their results after creation, a clone is necessary to create new results. Doing this generically has not been possible so far. This PR therefore adds a new option to the CloneOptions struct that allows changing the results of the created operation.
>
> The caller is responsible for ensuring that this is a valid operation and for updating the IRMapping accordingly afterwards, if required.
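To make the option's semantics concrete, here is a small, self-contained C++ model of the logic the patch adds. It is a hypothetical simplification, not MLIR's API. `Type` is just a string, while `Op`, `ResultRef`, and `cloneOp` are made-up stand-ins, and a `std::map` plays the role of `IRMapping`. Only `withResultTypes`, `shouldCloneResults`, and `resultTypesOr` mirror names from the patch.

```cpp
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: these only model the result-type logic of the patch.
using Type = std::string;

struct Op {
  std::vector<Type> resultTypes;
};

// A result is identified by (owning op, result index).
using ResultRef = std::pair<const Op *, unsigned>;

struct CloneOptions {
  // std::nullopt means "reuse the original result types".
  std::optional<std::vector<Type>> resultTypes;

  CloneOptions &withResultTypes(std::optional<std::vector<Type>> types) {
    resultTypes = std::move(types);
    return *this;
  }

  // Results count as "cloned" only when no override was requested.
  bool shouldCloneResults() const { return !resultTypes.has_value(); }

  // The override, if present, otherwise the caller-supplied defaults.
  const std::vector<Type> &resultTypesOr(const std::vector<Type> &defaults) const {
    return resultTypes ? *resultTypes : defaults;
  }
};

// Mirrors the patched clone(): result types come from the options, and
// old-to-new result mappings are recorded only when results were cloned.
std::unique_ptr<Op> cloneOp(const Op &op, std::map<ResultRef, ResultRef> &mapper,
                            const CloneOptions &options) {
  auto newOp = std::make_unique<Op>(Op{options.resultTypesOr(op.resultTypes)});
  if (options.shouldCloneResults())
    for (unsigned i = 0; i < op.resultTypes.size(); ++i)
      mapper[{&op, i}] = {newOp.get(), i};
  return newOp;
}
```

Using `std::optional` rather than a plain vector distinguishes "keep the original results" (`std::nullopt`) from "override with zero results" (an empty vector). When an override is given, the model leaves the mapping untouched. That mirrors the patch's contract: in real MLIR code, the caller would record any desired correspondences itself, for example with `IRMapping::map(Value, Value)`.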
From 6b9eba81423cac15991bd0317e869c4d258f47f2 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Mon, 2 Mar 2026 19:34:06 +0100
Subject: [PATCH 1/2] [mlir] Add option to cloning for different results
Co-authored-by: zero9178 <markus.boeck02 at gmail.com>
---
mlir/include/mlir/IR/Operation.h | 37 +++++++++++++++++-----
mlir/lib/IR/Operation.cpp | 31 ++++++++++++------
mlir/unittests/IR/OperationSupportTest.cpp | 16 ++++++++++
3 files changed, 67 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..5e40b7b865fe1 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -140,18 +140,20 @@ class alignas(8) Operation final
/// * Whether cloning should recursively traverse into the regions of the
/// operation or not.
/// * Whether cloning should also clone the operands of the operation.
+ /// * Whether to use different result types or clone them.
class CloneOptions {
public:
/// Default constructs an option with all flags set to false. That means all
/// parts of an operation that may optionally not be cloned, are not cloned.
CloneOptions();
- /// Constructs an instance with the clone regions and clone operands flags
- /// set accordingly.
- CloneOptions(bool cloneRegions, bool cloneOperands);
+ /// Constructs an instance with the options set accordingly.
+ CloneOptions(bool cloneRegions, bool cloneOperands,
+ std::optional<SmallVector<Type>> resultTypes);
- /// Returns an instance with all flags set to true. This is the default
- /// when using the clone method and clones all parts of the operation.
+ /// Returns an instance such that all elements of the operation are cloned.
+ /// This is the default when using the clone method and clones all parts of
+ /// the operation.
static CloneOptions all();
/// Configures whether cloning should traverse into any of the regions of
@@ -172,11 +174,29 @@ class alignas(8) Operation final
/// Returns whether operands should be cloned as well.
bool shouldCloneOperands() const { return cloneOperandsFlag; }
+ /// Configures different result types to use for the cloned operation.
+ /// If an empty optional, the result types are cloned from the original
+ /// operation.
+ CloneOptions &withResultTypes(std::optional<SmallVector<Type>> resultTypes);
+
+ /// Returns true if the results are cloned from the operation.
+ bool shouldCloneResults() const { return !resultTypes.has_value(); }
+
+ /// Returns the result types that should be used for the created operation
+ /// or `defaultResultTypes` if none were set.
+ TypeRange resultTypesOr(TypeRange defaultResultTypes) const {
+ if (resultTypes)
+ return *resultTypes;
+ return defaultResultTypes;
+ }
+
private:
/// Whether regions should be cloned.
bool cloneRegionsFlag : 1;
/// Whether operands should be cloned.
bool cloneOperandsFlag : 1;
+ /// New result types to use in the cloned operation.
+ std::optional<SmallVector<Type>> resultTypes;
};
/// Create a deep copy of this operation, remapping any operands that use
@@ -185,7 +205,8 @@ class alignas(8) Operation final
/// sub-operations to the corresponding operation that is copied, and adds
/// those mappings to the map.
/// Optionally, one may configure what parts of the operation to clone using
- /// the options parameter.
+ /// the options parameter. If parts of the operation (e.g. results or regions)
+ /// are not cloned, they will not appear in the mapper.
///
/// Calling this method from multiple threads is generally safe if through the
/// process of cloning no new uses of 'Value's from outside the operation are
@@ -194,8 +215,8 @@ class alignas(8) Operation final
/// mapper, it is possible to avoid adding uses to outside operands by
/// remapping them to 'Value's owned by the caller thread.
Operation *clone(IRMapping &mapper,
- CloneOptions options = CloneOptions::all());
- Operation *clone(CloneOptions options = CloneOptions::all());
+ const CloneOptions &options = CloneOptions::all());
+ Operation *clone(const CloneOptions &options = CloneOptions::all());
/// Create a partial copy of this operation without traversing into attached
/// regions. The new operation will have the same number of regions as the
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index bf8a918641dfb..46030936a3910 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -674,13 +674,18 @@ InFlightDiagnostic Operation::emitOpError(const Twine &message) {
//===----------------------------------------------------------------------===//
Operation::CloneOptions::CloneOptions()
- : cloneRegionsFlag(false), cloneOperandsFlag(false) {}
+ : cloneRegionsFlag(false), cloneOperandsFlag(false),
+ resultTypes(std::nullopt) {}
-Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands)
- : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {}
+Operation::CloneOptions::CloneOptions(
+ bool cloneRegions, bool cloneOperands,
+ std::optional<SmallVector<Type>> resultTypes)
+ : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands),
+ resultTypes(resultTypes) {}
Operation::CloneOptions Operation::CloneOptions::all() {
- return CloneOptions().cloneRegions().cloneOperands();
+ return CloneOptions().cloneRegions().cloneOperands().withResultTypes(
+ std::nullopt);
}
Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) {
@@ -693,6 +698,12 @@ Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) {
return *this;
}
+Operation::CloneOptions &Operation::CloneOptions::withResultTypes(
+ std::optional<SmallVector<Type>> resultTypes) {
+ this->resultTypes = std::move(resultTypes);
+ return *this;
+}
+
/// Create a deep copy of this operation but keep the operation regions empty.
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
/// to contain the results. The `mapResults` flag specifies whether the results
@@ -711,7 +722,7 @@ Operation *Operation::cloneWithoutRegions() {
/// them alone if no entry is present). Replaces references to cloned
/// sub-operations to the corresponding operation that is copied, and adds
/// those mappings to the map.
-Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
+Operation *Operation::clone(IRMapping &mapper, const CloneOptions &options) {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
@@ -728,7 +739,8 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
successors.push_back(mapper.lookupOrDefault(successor));
// Create the new operation.
- auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
+ auto *newOp = create(getLoc(), getName(),
+ options.resultTypesOr(getResultTypes()), operands, attrs,
getPropertiesStorage(), successors, getNumRegions());
mapper.map(this, newOp);
@@ -739,13 +751,14 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
}
// Remember the mapping of any results.
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- mapper.map(getResult(i), newOp->getResult(i));
+ if (options.shouldCloneResults())
+ for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+ mapper.map(getResult(i), newOp->getResult(i));
return newOp;
}
-Operation *Operation::clone(CloneOptions options) {
+Operation *Operation::clone(const CloneOptions &options) {
IRMapping mapper;
return clone(mapper, options);
}
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9f3e7ed34a27d..73babc4c4fbb7 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -356,4 +356,20 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
opWithProperty2->destroy();
}
+TEST(OperationCloneTest, CloneWithDifferentResults) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ Operation *useOp = createOp(&context, std::nullopt, builder.getI32Type());
+ IRMapping map;
+ Operation *cloneOp = useOp->clone(
+ map, Operation::CloneOptions::all().withResultTypes(
+ SmallVector<Type>{builder.getI32Type(), builder.getI16Type()}));
+
+ ASSERT_EQ(cloneOp->getNumResults(), 2u);
+ EXPECT_EQ(cloneOp->getResult(0).getType(), builder.getI32Type());
+ EXPECT_EQ(cloneOp->getResult(1).getType(), builder.getI16Type());
+ EXPECT_FALSE(map.contains(useOp->getResult(0)));
+}
+
} // namespace
From 84164758fb50cf000f668a83a99eda1528acbf44 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Mon, 2 Mar 2026 19:43:46 +0100
Subject: [PATCH 2/2] update test
---
mlir/unittests/IR/OperationSupportTest.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 73babc4c4fbb7..fc567bc4a8048 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -360,7 +360,7 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
MLIRContext context;
Builder builder(&context);
- Operation *useOp = createOp(&context, std::nullopt, builder.getI32Type());
+ Operation *useOp = createOp(&context, {}, builder.getI32Type());
IRMapping map;
Operation *cloneOp = useOp->clone(
map, Operation::CloneOptions::all().withResultTypes(