[Mlir-commits] [mlir] [mlir] [transform dialect] NamedSequenceOp build: honor arg_attrs when building (PR #168101)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 24 18:08:42 PST 2025
https://github.com/colawithsauce updated https://github.com/llvm/llvm-project/pull/168101
>From 5cebc4877dcb757514e82f85460cc09cd8e6666a Mon Sep 17 00:00:00 2001
From: colawithsauce <51731521+colawithsauce at users.noreply.github.com>
Date: Sat, 15 Nov 2025 02:24:22 +0800
Subject: [PATCH 1/4] [mlir] NamedSequence builder honors arg_attrs when
building
Previously, NamedSequenceOp::build silently dropped the argument attributes passed via `arg_attrs`, which made it impossible to create a legal named_sequence operation with this builder.
This patch stores `arg_attrs` in the op's properties
so that verification and printing see them.
---
mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 062606e7e10b6..415e2af491c07 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2571,6 +2571,11 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
TypeAttr::get(FunctionType::get(builder.getContext(),
rootType, resultTypes)));
state.attributes.append(attrs.begin(), attrs.end());
+ if (!argAttrs.empty()) {
+ SmallVector<Attribute> argAttrsVec(argAttrs.begin(), argAttrs.end());
+ state.getOrAddProperties<Properties>().arg_attrs =
+ ArrayAttr::get(builder.getContext(), argAttrsVec);
+ }
state.addRegion();
buildSequenceBody(builder, state, rootType,
>From a44361c3e4af1e3161b6bad3cbf609947c7ed3df Mon Sep 17 00:00:00 2001
From: colawithsauce <51731521+colawithsauce at users.noreply.github.com>
Date: Tue, 18 Nov 2025 03:27:33 +0000
Subject: [PATCH 2/4] [mlir][transform-dialect] Add unit test for named_sequence
build.
---
.../Dialect/Transform/CMakeLists.txt | 1 +
.../TransformNamedSequenceCreate.cpp | 45 +++++++++++++++++++
2 files changed, 46 insertions(+)
create mode 100644 mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 20cdc63966ec0..6bffd71d82fc2 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRTransformDialectTests
+ TransformNamedSequenceCreate.cpp
BuildOnlyExtensionTest.cpp
Preload.cpp
)
diff --git a/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
new file mode 100644
index 0000000000000..849b2b56021a8
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
@@ -0,0 +1,45 @@
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+TEST(NamedSequenceOpTest, ArgAttrsAreHonoredByBuilder) {
+ MLIRContext ctx;
+ ctx.loadDialect<TransformDialect>();
+
+ OpBuilder builder(&ctx);
+ auto module = ModuleOp::create(UnknownLoc::get(&ctx));
+ builder.setInsertionPointToEnd(module.getBody());
+
+ Location loc = UnknownLoc::get(&ctx);
+
+ static constexpr StringLiteral kMainSequenceName = "__transform_main";
+
+ NamedSequenceOp seqOp = builder.create<NamedSequenceOp>(
+ loc,
+ /*sym_name=*/kMainSequenceName,
+ /*rootType=*/builder.getType<AnyOpType>(),
+ /*resultType=*/TypeRange{},
+ [](OpBuilder &b, Location nested, Value rootH) {
+ b.create<YieldOp>(nested, ValueRange());
+ },
+ /*args=*/ArrayRef<NamedAttribute>{},
+ /*attrArgs=*/
+ ArrayRef<DictionaryAttr>{
+ builder.getDictionaryAttr(ArrayRef<NamedAttribute>{
+ builder.getNamedAttr(TransformDialect::kArgConsumedAttrName,
+ builder.getUnitAttr())})});
+
+ // 检查 body argument 上有没有 transform.consumed
+ Block &body = seqOp.getBody().front();
+ ASSERT_EQ(body.getNumArguments(), 1u);
+
+ StringAttr arg0Name = seqOp.getArgAttrsAttrName();
+ EXPECT_TRUE(arg0Name);
+}
\ No newline at end of file
>From 298aa73b50ce76021740893c19f4876c328b267e Mon Sep 17 00:00:00 2001
From: colawithsauce <51731521+colawithsauce at users.noreply.github.com>
Date: Fri, 21 Nov 2025 06:08:21 +0000
Subject: [PATCH 3/4] [mlir][transform-dialect] Write comments in English and add
trailing newline.
---
.../Dialect/Transform/TransformNamedSequenceCreate.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
index 849b2b56021a8..69f043cb113fc 100644
--- a/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
+++ b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
@@ -36,10 +36,10 @@ TEST(NamedSequenceOpTest, ArgAttrsAreHonoredByBuilder) {
builder.getNamedAttr(TransformDialect::kArgConsumedAttrName,
builder.getUnitAttr())})});
- // 检查 body argument 上有没有 transform.consumed
+ // check if body argument contains any attributes
Block &body = seqOp.getBody().front();
ASSERT_EQ(body.getNumArguments(), 1u);
StringAttr arg0Name = seqOp.getArgAttrsAttrName();
EXPECT_TRUE(arg0Name);
-}
\ No newline at end of file
+}
>From 732adc55d4ffeb2077d5c04ef4519ee82db6f195 Mon Sep 17 00:00:00 2001
From: colawithsauce <cola_with_sauce at foxmail.com>
Date: Sat, 22 Nov 2025 14:37:50 +0800
Subject: [PATCH 4/4] [mlir][transform-dialect] Fix named sequence build op
unit test.
---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 2 +-
.../TransformNamedSequenceCreate.cpp | 30 +++++++++++--------
2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 415e2af491c07..2285e3ddcc981 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2572,7 +2572,7 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
rootType, resultTypes)));
state.attributes.append(attrs.begin(), attrs.end());
if (!argAttrs.empty()) {
- SmallVector<Attribute> argAttrsVec(argAttrs.begin(), argAttrs.end());
+ SmallVector<Attribute> argAttrsVec(argAttrs);
state.getOrAddProperties<Properties>().arg_attrs =
ArrayAttr::get(builder.getContext(), argAttrsVec);
}
diff --git a/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
index 69f043cb113fc..1dc1cba33d92e 100644
--- a/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
+++ b/mlir/unittests/Dialect/Transform/TransformNamedSequenceCreate.cpp
@@ -14,20 +14,17 @@ TEST(NamedSequenceOpTest, ArgAttrsAreHonoredByBuilder) {
ctx.loadDialect<TransformDialect>();
OpBuilder builder(&ctx);
- auto module = ModuleOp::create(UnknownLoc::get(&ctx));
- builder.setInsertionPointToEnd(module.getBody());
-
Location loc = UnknownLoc::get(&ctx);
+ auto module = ModuleOp::create(loc);
+ builder.setInsertionPointToEnd(module.getBody());
- static constexpr StringLiteral kMainSequenceName = "__transform_main";
-
- NamedSequenceOp seqOp = builder.create<NamedSequenceOp>(
- loc,
- /*sym_name=*/kMainSequenceName,
+ NamedSequenceOp seqOp = NamedSequenceOp::create(
+ builder, loc,
+ /*sym_name=*/transform::TransformDialect::kTransformEntryPointSymbolName,
/*rootType=*/builder.getType<AnyOpType>(),
/*resultType=*/TypeRange{},
[](OpBuilder &b, Location nested, Value rootH) {
- b.create<YieldOp>(nested, ValueRange());
+ YieldOp::create(b, nested, ValueRange());
},
/*args=*/ArrayRef<NamedAttribute>{},
/*attrArgs=*/
@@ -36,10 +33,19 @@ TEST(NamedSequenceOpTest, ArgAttrsAreHonoredByBuilder) {
builder.getNamedAttr(TransformDialect::kArgConsumedAttrName,
builder.getUnitAttr())})});
- // check if body argument contains any attributes
+ // Check if body argument contains any attributes.
Block &body = seqOp.getBody().front();
ASSERT_EQ(body.getNumArguments(), 1u);
- StringAttr arg0Name = seqOp.getArgAttrsAttrName();
- EXPECT_TRUE(arg0Name);
+ auto arg0Attr = seqOp.getArgAttrDict(0);
+ EXPECT_TRUE(arg0Attr);
+
+ auto arg0Name = arg0Attr.getNamed(TransformDialect::kArgConsumedAttrName);
+ EXPECT_TRUE(arg0Name.has_value());
+
+ EXPECT_EQ(arg0Name.value().getName(), TransformDialect::kArgConsumedAttrName);
+
+ auto expectedFalse =
+ arg0Attr.getNamed(TransformDialect::kArgReadOnlyAttrName);
+ EXPECT_FALSE(expectedFalse.has_value());
}
More information about the Mlir-commits
mailing list