[Mlir-commits] [mlir] 0c93589 - [mlir][Transform] NFC - Add a C++ builder for NamedSequenceOp
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 23 09:39:04 PDT 2023
Author: Nicolas Vasilache
Date: 2023-10-23T16:38:49Z
New Revision: 0c9358948b0478a169da3dabd82a49d15ba2aff4
URL: https://github.com/llvm/llvm-project/commit/0c9358948b0478a169da3dabd82a49d15ba2aff4
DIFF: https://github.com/llvm/llvm-project/commit/0c9358948b0478a169da3dabd82a49d15ba2aff4.diff
LOG: [mlir][Transform] NFC - Add a C++ builder for NamedSequenceOp
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 5bc92e8e954eae7..b14c89eadb097d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -867,6 +867,17 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
+  let builders = [
+    // Build a named sequence `symName` whose body block takes a single
+    // argument of type `rootType` and whose function type returns
+    // `resultTypes`; the body is populated by `bodyBuilder`. `attrs` and
+    // `argAttrs` default to empty.
+    OpBuilder<(ins
+      "StringRef":$symName,
+      "Type":$rootType,
+      "TypeRange":$resultTypes,
+      "SequenceBodyBuilderFn":$bodyBuilder,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+      CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
+  ];
+
let extraClassDeclaration = [{
::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
return getFunctionType().getInputs();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7b1badd0adae9ff..8db77b6059dd2e3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1927,6 +1927,48 @@ LogicalResult transform::NamedSequenceOp::verify() {
return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
}
+/// Creates the entry block of the most recently added region in `state` and
+/// lets `bodyBuilder` populate it.
+///
+/// The block's arguments are `bbArgType` followed by `extraBindingTypes`, all
+/// located at `state.location`. The builder's insertion point is restored on
+/// return. `FnTy` is invoked either with (builder, location, first block
+/// argument) when it takes three arguments, or additionally with the remaining
+/// block arguments otherwise.
+template <typename FnTy>
+static void buildSequenceBody(OpBuilder &builder, OperationState &state,
+                              Type bbArgType, TypeRange extraBindingTypes,
+                              FnTy bodyBuilder) {
+  Location loc = state.location;
+
+  // Block signature: the root argument first, then any extra bindings.
+  SmallVector<Type> argTypes;
+  argTypes.reserve(extraBindingTypes.size() + 1);
+  argTypes.push_back(bbArgType);
+  argTypes.append(extraBindingTypes.begin(), extraBindingTypes.end());
+  SmallVector<Location> argLocs(argTypes.size(), loc);
+
+  // Keep the caller's insertion point intact; the builder is repositioned
+  // into the new block below.
+  OpBuilder::InsertionGuard guard(builder);
+  Region &region = *state.regions.back();
+  Block *entry =
+      builder.createBlock(&region, region.begin(), argTypes, argLocs);
+  builder.setInsertionPointToStart(entry);
+
+  BlockArgument root = entry->getArgument(0);
+  if constexpr (llvm::function_traits<FnTy>::num_args == 3)
+    bodyBuilder(builder, loc, root);
+  else
+    bodyBuilder(builder, loc, root, entry->getArguments().drop_front());
+}
+
+/// Builds a `transform.named_sequence` named `symName`.
+///
+/// The op's function type has the single input `rootType` and the results
+/// `resultTypes`; `attrs` are appended to the op's attributes. `argAttrs`,
+/// when non-empty, must contain exactly one dictionary (for the root
+/// argument) and is recorded as the op's argument attributes. The body block
+/// is created with a `rootType` argument and populated by `bodyBuilder`.
+void transform::NamedSequenceOp::build(OpBuilder &builder,
+                                       OperationState &state, StringRef symName,
+                                       Type rootType, TypeRange resultTypes,
+                                       SequenceBodyBuilderFn bodyBuilder,
+                                       ArrayRef<NamedAttribute> attrs,
+                                       ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(symName));
+  state.addAttribute(getFunctionTypeAttrName(state.name),
+                     TypeAttr::get(FunctionType::get(builder.getContext(),
+                                                     rootType, resultTypes)));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  // Honor `argAttrs` instead of silently dropping it. The function type built
+  // above has exactly one input, so at most one dictionary is meaningful.
+  if (!argAttrs.empty()) {
+    assert(argAttrs.size() == 1 &&
+           "expected one argument attribute dictionary for the root handle");
+    function_interface_impl::addArgAndResultAttrs(
+        builder, state, argAttrs, /*resultAttrs=*/{},
+        getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
+  }
+
+  buildSequenceBody(builder, state, rootType,
+                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
@@ -2264,31 +2306,6 @@ void transform::SequenceOp::getRegionInvocationBounds(
bounds.emplace_back(1, 1);
}
-template <typename FnTy>
-static void buildSequenceBody(OpBuilder &builder, OperationState &state,
- Type bbArgType, TypeRange extraBindingTypes,
- FnTy bodyBuilder) {
- SmallVector<Type> types;
- types.reserve(1 + extraBindingTypes.size());
- types.push_back(bbArgType);
- llvm::append_range(types, extraBindingTypes);
-
- OpBuilder::InsertionGuard guard(builder);
- Region *region = state.regions.back().get();
- Block *bodyBlock =
- builder.createBlock(region, region->begin(), types,
- SmallVector<Location>(types.size(), state.location));
-
- // Populate body.
- builder.setInsertionPointToStart(bodyBlock);
- if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
- bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
- } else {
- bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
- bodyBlock->getArguments().drop_front());
- }
-}
-
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes,
FailurePropagationMode failurePropagationMode,
More information about the Mlir-commits mailing list