[Mlir-commits] [mlir] 0c93589 - [mlir][Transform] NFC - Add a C++ builder for NamedSequenceOp

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 23 09:39:04 PDT 2023


Author: Nicolas Vasilache
Date: 2023-10-23T16:38:49Z
New Revision: 0c9358948b0478a169da3dabd82a49d15ba2aff4

URL: https://github.com/llvm/llvm-project/commit/0c9358948b0478a169da3dabd82a49d15ba2aff4
DIFF: https://github.com/llvm/llvm-project/commit/0c9358948b0478a169da3dabd82a49d15ba2aff4.diff

LOG: [mlir][Transform] NFC - Add a C++ builder for NamedSequenceOp

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 5bc92e8e954eae7..b14c89eadb097d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -867,6 +867,17 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
+  let builders = [
+    // Build a named sequence of type `(rootType) -> (resultTypes)`.
+    OpBuilder<(ins
+      "StringRef":$symName,
+      "Type":$rootType,
+      "TypeRange":$resultTypes,
+      "SequenceBodyBuilderFn":$bodyBuilder,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+      CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
+  ];
+
   let extraClassDeclaration = [{
     ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
       return getFunctionType().getInputs();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7b1badd0adae9ff..8db77b6059dd2e3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1927,6 +1927,66 @@ LogicalResult transform::NamedSequenceOp::verify() {
   return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
 }
 
+/// Creates a block at the start of the last region added to `state`. The
+/// block has one argument of `bbArgType` followed by one argument per entry
+/// of `extraBindingTypes`, all located at `state.location`. `bodyBuilder` is
+/// then invoked with the insertion point at the start of that block; a
+/// 3-argument callback receives only the first block argument, otherwise the
+/// remaining arguments are passed as a fourth parameter. The builder's
+/// insertion point is restored on return.
+template <typename FnTy>
+static void buildSequenceBody(OpBuilder &builder, OperationState &state,
+                              Type bbArgType, TypeRange extraBindingTypes,
+                              FnTy bodyBuilder) {
+  SmallVector<Type> types;
+  types.reserve(1 + extraBindingTypes.size());
+  types.push_back(bbArgType);
+  llvm::append_range(types, extraBindingTypes);
+
+  OpBuilder::InsertionGuard guard(builder);
+  Region *region = state.regions.back().get();
+  Block *bodyBlock =
+      builder.createBlock(region, region->begin(), types,
+                          SmallVector<Location>(types.size(), state.location));
+
+  // Populate body.
+  builder.setInsertionPointToStart(bodyBlock);
+  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
+    bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+  } else {
+    bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
+                bodyBlock->getArguments().drop_front());
+  }
+}
+
+/// Builds a `transform.named_sequence` op called `symName` with function type
+/// `(rootType) -> (resultTypes)`. `attrs` are appended verbatim to the op's
+/// attributes and `argAttrs` supplies the per-argument attribute
+/// dictionaries. The single body block takes one `rootType` argument and is
+/// populated by `bodyBuilder`.
+void transform::NamedSequenceOp::build(OpBuilder &builder,
+                                       OperationState &state, StringRef symName,
+                                       Type rootType, TypeRange resultTypes,
+                                       SequenceBodyBuilderFn bodyBuilder,
+                                       ArrayRef<NamedAttribute> attrs,
+                                       ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(symName));
+  state.addAttribute(getFunctionTypeAttrName(state.name),
+                     TypeAttr::get(FunctionType::get(builder.getContext(),
+                                                     rootType, resultTypes)));
+  state.attributes.append(attrs.begin(), attrs.end());
+  // Attach the caller-provided argument attributes instead of dropping them;
+  // this builder takes no result attributes.
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/{},
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
+  state.addRegion();
+
+  buildSequenceBody(builder, state, rootType,
+                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
@@ -2264,31 +2306,6 @@ void transform::SequenceOp::getRegionInvocationBounds(
   bounds.emplace_back(1, 1);
 }
 
-template <typename FnTy>
-static void buildSequenceBody(OpBuilder &builder, OperationState &state,
-                              Type bbArgType, TypeRange extraBindingTypes,
-                              FnTy bodyBuilder) {
-  SmallVector<Type> types;
-  types.reserve(1 + extraBindingTypes.size());
-  types.push_back(bbArgType);
-  llvm::append_range(types, extraBindingTypes);
-
-  OpBuilder::InsertionGuard guard(builder);
-  Region *region = state.regions.back().get();
-  Block *bodyBlock =
-      builder.createBlock(region, region->begin(), types,
-                          SmallVector<Location>(types.size(), state.location));
-
-  // Populate body.
-  builder.setInsertionPointToStart(bodyBlock);
-  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
-    bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
-  } else {
-    bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
-                bodyBlock->getArguments().drop_front());
-  }
-}
-
 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                   TypeRange resultTypes,
                                   FailurePropagationMode failurePropagationMode,


        


More information about the Mlir-commits mailing list