[Mlir-commits] [mlir] bbf4436 - [mlir][linalg] Remove the StructuredOp capture mechanism.

Tobias Gysi llvmlistbot at llvm.org
Mon Jun 28 00:59:09 PDT 2021


Author: Tobias Gysi
Date: 2021-06-28T07:57:40Z
New Revision: bbf4436a82febeab811af59b20d6928e694b4178

URL: https://github.com/llvm/llvm-project/commit/bbf4436a82febeab811af59b20d6928e694b4178
DIFF: https://github.com/llvm/llvm-project/commit/bbf4436a82febeab811af59b20d6928e694b4178.diff

LOG: [mlir][linalg] Remove the StructuredOp capture mechanism.

After https://reviews.llvm.org/D104109, structured ops support scalar inputs. As a result, the capture mechanism meant to pass non-shaped parameters became redundant. The patch removes the capture semantics now that the FillOp has migrated to use scalar operands (https://reviews.llvm.org/D104121).

Differential Revision: https://reviews.llvm.org/D104785

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/Linalg.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Bindings/Python/DialectLinalg.cpp
    mlir/lib/CAPI/Dialect/Linalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 6e20eec16481a..27f2f7bc897f7 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -18,11 +18,9 @@ extern "C" {
 #endif
 
 /// Apply the special region builder for the builtin named Linalg op.
-/// The list of `capture` MlirValue is passed as-is to the region builder.
 /// Assert that `op` is a builtin named Linalg op.
 MLIR_CAPI_EXPORTED void
-mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
-                                   intptr_t n, MlirValue const *mlirCaptures);
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 9d1e3baad8ee7..092d22983d3f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -49,7 +49,7 @@ def Linalg_Dialect : Dialect {
       kInplaceableAttrName = "linalg.inplaceable";
 
     using RegionBuilderFunType =
-      llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &, ValueRange)>;
+      llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
       return namedStructuredOpRegionBuilders.lookup(name);
     }

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 8c0d4763376c3..ad91e23607141 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -901,7 +901,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ValueRange)>",
+      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f83f484187c97..18f5beeddf2ea 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -153,10 +153,8 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
     Value getSource() { return input();}
     Value getTarget() { return output(); }
 
-    static void regionBuilder(
-      ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
-    static std::function<
-      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
+    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
+    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
     getRegionBuilder() {
       return &regionBuilder;
     }
@@ -200,10 +198,8 @@ def FillOp : LinalgStructured_Op<"fill", []> {
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
 
-    static void regionBuilder(
-      ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
-    static std::function<
-      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
+    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
+    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
     getRegionBuilder() {
       return &regionBuilder;
     }
@@ -291,8 +287,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
       return padding().getValue().getValue<int64_t>({i, 1});
     }
 
-    static std::function<
-      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
+    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -533,8 +528,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
         library_call()->str() : "op_has_no_registered_library_name";
     }
 
-    static std::function<
-      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
+    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
     getRegionBuilder() {
       return nullptr;
     }

diff  --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index dfac96db74b12..a2a54249e6d68 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -21,15 +21,10 @@ using namespace mlir::python;
 void mlir::python::populateDialectLinalgSubmodule(py::module m) {
   m.def(
       "fill_builtin_region",
-      [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
-        llvm::SmallVector<MlirValue, 4> mlirOperands;
-        mlirOperands.reserve(captures.size());
-        for (auto v : captures)
-          mlirOperands.push_back(py::cast<PyValue *>(v)->get());
-        mlirLinalgFillBuiltinNamedOpRegion(
-            dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
+      [](PyDialectDescriptor &dialect, PyOperation &op) {
+        mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
       },
-      py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
+      py::arg("dialect"), py::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 }

diff  --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index be0d5448819d9..902599f3b9adf 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -16,13 +16,8 @@ using namespace mlir::linalg;
 /// Apply the special region builder for the builtin named Linalg op.
 /// Assert that `op` is a builtin named Linalg op.
 void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
-                                        MlirOperation mlirOp, intptr_t n,
-                                        MlirValue const *mlirCaptures) {
+                                        MlirOperation mlirOp) {
   Operation *op = unwrap(mlirOp);
-  SmallVector<Value> captures;
-  captures.reserve(n);
-  for (unsigned idx = 0; idx < n; ++idx)
-    captures.push_back(unwrap(mlirCaptures[idx]));
 
   LinalgDialect::RegionBuilderFunType fun =
       static_cast<LinalgDialect *>(unwrap(linalgDialect))
@@ -41,7 +36,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
   b.setInsertionPointToStart(body);
-  fun(b, *body, captures);
+  fun(b, *body);
 }
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 109a1c60ddc39..11cb3e15c0e0c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -43,20 +43,19 @@ using namespace mlir::linalg;
 /// defined C++ ops.
 /// This is used by both builders and parsers.
 /// This function creates the block in the region with arguments corresponding
-/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
-/// to be ShapedType.
+/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
+/// asserted to be of ShapedType.
 template <typename NamedStructuredOpType>
 static void fillStructuredOpRegion(
     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
-    TypeRange outputTypes, ValueRange captures = {},
+    TypeRange outputTypes,
     std::function<void(unsigned, unsigned)> errorHandler = nullptr);
 
 /// Generic entry point to create both the region and the block of a LinalgOp.
 template <typename NamedStructuredOpType>
 static void
 createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
-                                TypeRange inputTypes, TypeRange outputTypes,
-                                ValueRange captures = {});
+                                TypeRange inputTypes, TypeRange outputTypes);
 
 /// Common parsing and printing used for both named structured ops created by
 /// ods-gen and by manually defined C++ ops. Does not handle regions.
@@ -72,17 +71,15 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes,
-                             ArrayRef<OpAsmParser::OperandType> captures = {});
+                             TypeRange inputTypes, TypeRange outputTypes);
 
 static ParseResult
 parseNamedStructuredOpResults(OpAsmParser &parser,
                               SmallVectorImpl<Type> &resultTypes);
 
 template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
-                       ArrayRef<OpAsmParser::OperandType> captures = {});
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result);
 
 static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                           TypeRange resultTypes);
@@ -323,8 +320,7 @@ class RegionBuilderHelper {
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
-void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                           ValueRange captures) {
+void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
   assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
   b.create<linalg::YieldOp>(block.getArgument(0));
 }
@@ -403,8 +399,7 @@ void CopyOp::getEffects(
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                           ValueRange captures) {
+void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
   assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
   b.create<linalg::YieldOp>(block.getArgument(0));
 }
@@ -2799,7 +2794,6 @@ template <typename NamedStructuredOpType>
 static void
 fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                        TypeRange inputTypes, TypeRange outputTypes,
-                       ValueRange captures,
                        std::function<void(unsigned, unsigned)> errorHandler) {
   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
 
@@ -2823,7 +2817,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   opBuilder.setInsertionPointToStart(body);
   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  NamedStructuredOpType::regionBuilder(b, *body, captures);
+  NamedStructuredOpType::regionBuilder(b, *body);
 
   // indexing_maps is an auto-generated method.
 
@@ -2835,11 +2829,10 @@ template <typename NamedStructuredOpType>
 void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
                                      OperationState &result,
                                      TypeRange inputTypes,
-                                     TypeRange outputTypes,
-                                     ValueRange captures) {
+                                     TypeRange outputTypes) {
   Region &region = *result.addRegion();
   fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes, captures,
+      opBuilder, region, inputTypes, outputTypes,
       [&](unsigned expected, unsigned actual) {
         assert(expected != actual && "incorrect number of arguments");
       });
@@ -2902,15 +2895,14 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes,
-                             ArrayRef<OpAsmParser::OperandType> captures) {
+                             TypeRange inputTypes, TypeRange outputTypes) {
   ParseResult res = success();
   OpBuilder opBuilder(parser.getBuilder().getContext());
   // Resolve `captures` into `capturedValues` at parse time so we can build the
   // region with captures.
   SmallVector<Value> capturedValues;
   fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes, capturedValues,
+      opBuilder, region, inputTypes, outputTypes,
       [&](unsigned expected, unsigned actual) {
         res = parser.emitError(
             parser.getCurrentLocation(),
@@ -2931,11 +2923,9 @@ parseNamedStructuredOpResults(OpAsmParser &parser,
 }
 
 template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
-                       ArrayRef<OpAsmParser::OperandType> captures) {
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result) {
   // TODO: Enable when ods-gen supports captures.
-  assert(captures.empty() && "unexpected captures for named structured ops");
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
@@ -2949,7 +2939,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
 
   std::unique_ptr<Region> region = std::make_unique<Region>();
   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
-          parser, *region, inputTypes, outputTypes, captures))
+          parser, *region, inputTypes, outputTypes))
     return failure();
   result.addRegion(std::move(region));
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index d5e619719fd7f..d0d14f86c54dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -63,8 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
       iterators,
       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
         ImplicitLocOpBuilder b(loc, bodyBuilder);
-        regionBuilder(b, *bodyBuilder.getBlock(),
-                      /*captures=*/{});
+        regionBuilder(b, *bodyBuilder.getBlock());
       });
 }
 

diff  --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index c7ddfb962375d..bce4e08ae3a06 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -33,7 +33,7 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
         ip=ip)
     OpView.__init__(self, op)
     linalgDialect = Context.current.get_dialect_descriptor("linalg")
-    fill_builtin_region(linalgDialect, self.operation, [])
+    fill_builtin_region(linalgDialect, self.operation)
     # TODO: self.result is None. When len(results) == 1 we expect it to be
     # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
     # in the generator of _linalg_ops_gen.py where we have:

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 84adc8b260c49..471961f837bf3 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -24,7 +24,7 @@
 //  IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
 //
 //       IMPL:  void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
-//       IMPL:    Block &block, ValueRange captures) {
+//       IMPL:    Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@@ -49,7 +49,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //       IMPL:  AffineMap::get(3, 3, {d0, d1}, context)
 //
 //       IMPL:  Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
-//       IMPL:    Block &block, ValueRange captures) {
+//       IMPL:    Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@@ -74,7 +74,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //       IMPL:  AffineMap::get(4, 4, {d0, d1, d2}, context)
 //
 //       IMPL:  Test3Op::regionBuilder(ImplicitLocOpBuilder &b,
-//       IMPL:    Block &block, ValueRange captures) {
+//       IMPL:    Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@@ -182,7 +182,7 @@ def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
 
 // Test output arg order.
 // IMPL-LABEL:  void Test8Op::regionBuilder(ImplicitLocOpBuilder &b,
-//       IMPL:    Block &block, ValueRange captures) {
+//       IMPL:    Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = b.create<SubFOp>([[d]], [[c]]);
@@ -199,7 +199,7 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
 //       IMPL:    auto map1 = AffineMap::get(2, 2, {d1}, context);
 //       IMPL:    auto map2 = AffineMap::get(2, 2, {d0}, context);
 // IMPL-LABEL:  void Test9Op::regionBuilder(ImplicitLocOpBuilder &b,
-//       IMPL:    Block &block, ValueRange captures) {
+//       IMPL:    Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[c:.*]](args[2]);
 ods_def<Test9Op>:
 def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 471890e5f4a45..3c8b5271cf5c3 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -76,7 +76,7 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:      TypeRange(outputs)
 
 # IMPL-LABEL:  void Test1Op::regionBuilder(
-#       IMPL:    ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+#       IMPL:    ImplicitLocOpBuilder &b, Block &block)
 #       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
 #   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
@@ -163,8 +163,7 @@ structured_op: !LinalgStructuredOpConfig
 #       IMPL:  auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
 #       IMPL:  "missing indexing map required attribute 'strides'"
 
-#       IMPL:  void Test2Op::regionBuilder(
-#  IMPL-NEXT:    ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+#       IMPL:  void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
 #  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&
 
 #       IMPL:   yields.push_back(block.getArgument(0));

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index faa2835d589e7..1bdb5b8806d0d 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1923,7 +1923,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs)/*, TODO: support captures*/);
+            TypeRange(outputs));
         }]>,
         OpBuilder<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -1941,7 +1941,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs)/*, TODO: support captures*/);
+            TypeRange(outputs));
         }]>,
         OpBuilder<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -1956,7 +1956,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
       let parser = [{{
-        return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
+        return ::parseNamedStructuredOp<{0}>(parser, result);
       }];
       let hasFolder = 1;
 
@@ -1964,10 +1964,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         // Auto-generated.
         ArrayAttr iterator_types();
         ArrayAttr indexing_maps();
-        static void regionBuilder(ImplicitLocOpBuilder &b,
-                                  Block &block, ValueRange captures);
-        static std::function<void(ImplicitLocOpBuilder &b,
-                                  Block &, ValueRange)> getRegionBuilder() {{
+        static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
+        static std::function<void(ImplicitLocOpBuilder &b, Block &)>
+        getRegionBuilder() {{
           return regionBuilder;
         }
 
@@ -2035,7 +2034,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_builder,
           $_state,
           TypeRange(inputs),
-          TypeRange(outputs)/*, TODO: support captures*/);
+          TypeRange(outputs));
         {2}
       }]>
     )FMT";
@@ -2354,8 +2353,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
   };
 
   const char *regionBuilderFmt = R"FMT(
-  void {0}::regionBuilder(ImplicitLocOpBuilder &b,
-                          Block &block, ValueRange captures) {
+  void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
     auto args = block.getArguments();
     Value {1};
     {2}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 00c4096d095cf..83447f4930170 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -511,10 +511,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
       // Auto-generated.
       ArrayAttr iterator_types();
       ArrayAttr indexing_maps();
-      static void regionBuilder(
-        ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
-      static std::function<
-        void(ImplicitLocOpBuilder &b, Block &, ValueRange)>
+      static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
+      static std::function<void(ImplicitLocOpBuilder &b, Block &)>
       getRegionBuilder() {{
         return regionBuilder;
       }
@@ -883,8 +881,7 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
     // {1}: Number of args
     // {2}: Statements
     static const char structuredOpRegionBuilderFormat[] = R"FMT(
-void {0}::regionBuilder(
-    ImplicitLocOpBuilder &b, Block &block, ValueRange captures) {{
+void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
   assert({1} > 0 && block.getNumArguments() == {1} &&
          "{0} regionBuilder expects {1} (>=0) args");
   RegionBuilderHelper helper(block.getArgument(0).getContext(), block);


        


More information about the Mlir-commits mailing list