[Mlir-commits] [mlir] 973e133 - [mlir][Linalg] Improve region support in Linalg ops.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 12 06:54:38 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-12T14:51:03Z
New Revision: 973e133b769773c89ce4b8bbfd6c77612d2ff9d4

URL: https://github.com/llvm/llvm-project/commit/973e133b769773c89ce4b8bbfd6c77612d2ff9d4
DIFF: https://github.com/llvm/llvm-project/commit/973e133b769773c89ce4b8bbfd6c77612d2ff9d4.diff

LOG: [mlir][Linalg] Improve region support in Linalg ops.

This revision takes advantage of the newly extended `ref` directive in the assembly format
to allow better region handling for LinalgOps. Specifically, FillOp and CopyOp now build their regions explicitly, which allows retiring older behavior that relied on op-specific knowledge in both lowering to loops and vectorization.

Differential Revision: https://reviews.llvm.org/D96598

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Transforms/copy-removal.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index c26f02208215f..95656ebd99835 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1056,20 +1056,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     //===------------------------------------------------------------------===//
     // Other static interface methods.
     //===------------------------------------------------------------------===//
-    StaticInterfaceMethod<
-      /*desc=*/[{
-        Create an operation of the current type with the given location,
-        operands, and attributes.
-      }],
-      /*retTy=*/"Operation *",
-      /*methodName=*/"create",
-      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
-           "ValueRange":$operands,
-           "ArrayRef<NamedAttribute>":$attributes), [{
-        return builder.create<ConcreteOp>(
-          loc, resultTypes, operands, attributes);
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Clone the current operation with the given location and operands. This
@@ -1082,14 +1068,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
            "ValueRange":$operands),
       [{
-        BlockAndValueMapping map;
-        unsigned numRegions = $_op->getNumRegions();
-        Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
-        assert(res->getNumRegions() == numRegions && "inconsistent # regions");
-        for (unsigned ridx = 0; ridx < numRegions; ++ridx)
-          $_op->getRegion(ridx).cloneInto(
-            &res->getRegion(ridx), map);
-        return res;
+        BlockAndValueMapping bvm;
+        OperationState state(
+          loc, ConcreteOp::getOperationName(), operands, resultTypes,
+          $_op->getAttrs());
+        for (Region &r : $_op->getRegions())
+          r.cloneInto(state.addRegion(), bvm);
+        return b.createOperation(state);
       }]
     >,
     StaticInterfaceMethod<
@@ -1098,7 +1083,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(Block &)>",
+      /*retTy=*/"std::function<void(Block &, ValueRange)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8988a3a11efd8..05a6bb766dd0c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -110,14 +110,13 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
     AnyStridedMemRef:$output,
     OptionalAttr<AffineMapAttr>:$inputPermutation,
     OptionalAttr<AffineMapAttr>:$outputPermutation);
+  let regions = (region AnyRegion:$region);
 
-  // TODO: this should go away once the usage of OptionalAttr triggers emission
-  // of builders with default arguments left unspecified.
-  let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
-    [{
-      return build(
-        $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
-    }]>];
+  let builders = [
+    OpBuilderDAG<(ins "Value":$input, "Value":$output,
+      CArg<"AffineMap", "AffineMap()">:$inputPermutation,
+      CArg<"AffineMap", "AffineMap()">:$outputPermutation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
 
   let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return getOperands().take_front(); }
@@ -146,24 +145,31 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
     Value getSource() { return input();}
     Value getTarget() { return output(); }
 
-    static std::function<void(Block &)> getRegionBuilder() {
-      return nullptr;
+    static void regionBuilder(Block &block, ValueRange captures);
+    static std::function<void(Block &block, ValueRange captures)>
+    getRegionBuilder() {
+      return &regionBuilder;
     }
+    static unsigned getNumRegionArgs() { return 2; }
   }];
   let verifier = [{ return ::verify(*this); }];
 
   let assemblyFormat = [{
-    `(` operands `)` attr-dict `:` type(operands)
+    `(` $input `,` $output `)` attr-dict `:`
+        type($input) `,` type($output)
+      custom<CopyOpRegion>($region, ref(type($input)), ref(type($output)))
   }];
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let skipDefaultBuilders = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", []> {
   let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
   let results = (outs Optional<AnyRankedTensor>:$result);
+  let regions = (region AnyRegion:$region);
   let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
@@ -183,13 +189,18 @@ def FillOp : LinalgStructured_Op<"fill", []> {
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
 
-    static std::function<void(Block &)> getRegionBuilder() {
-      return nullptr;
+    static void regionBuilder(Block &block, ValueRange captures);
+    static std::function<void(Block &block, ValueRange captures)>
+    getRegionBuilder() {
+      return &regionBuilder;
     }
+    static unsigned getNumRegionArgs() { return 1; }
   }];
 
   let assemblyFormat = [{
-    `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
+    `(` $output `,` $value `)` attr-dict `:`
+        type($output) `,` type($value) (`->` type($result)^)?
+      custom<FillOpRegion>($region, ref(type($output)), ref($value))
   }];
 
   let builders = [
@@ -268,7 +279,8 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
       return padding().getValue().getValue<int64_t>({i, 1});
     }
 
-    static std::function<void(Block &)> getRegionBuilder() {
+    static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
+    {
       return nullptr;
     }
   }];
@@ -519,7 +531,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
         library_call()->str() : "op_has_no_registered_library_name";
     }
 
-    static std::function<void(Block &)> getRegionBuilder() {
+    static std::function<void(Block &, ValueRange)> getRegionBuilder() {
       return nullptr;
     }
   }];

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 8b53ecb740756..276b124a9f10f 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -154,7 +154,13 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
   if (in == op.input() && out == op.output())
     return failure();
 
-  rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+  if (!libraryCallName)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<mlir::CallOp>(
+      op, libraryCallName.getValue(), TypeRange(),
+      createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 8bb104df5bd61..46e42e26c5c8c 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -27,8 +27,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
     ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
     function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
     ArrayRef<Attribute> otherAttributes) {
-  OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
-
   // Build maps
   SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
   exprsList.reserve(inputs.size() + outputs.size());
@@ -54,13 +52,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
               resultTensorTypes,
               inputValues,
               outputValues,
-              builder.getAffineMapArrayAttr(maps),
-              builder.getStrArrayAttr(iteratorStrTypes),
-              StringAttr() /*doc*/,
-              StringAttr() /*library_call*/,
-              ArrayAttr() /*sparse*/
-              /* TODO: other attributes in op */
-              )
+              maps,
+              iteratorStrTypes,
+              ""/*doc*/,
+              ""/*library_call*/)
           .getOperation();
   // clang-format on
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 49e3c2b87ae79..989a164007d5e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -33,32 +33,53 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 /// Forward declarations.
+
+/// Generic entry point to create the block for the region of a LinalgOp.
+/// This is used by both named structured ops created by ods-gen and by manually
+/// defined C++ ops.
+/// This is used by both builders and parsers.
+/// This function creates the block in the region with arguments corresponding
+/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
+/// to be ShapedType.
+template <typename NamedStructuredOpType>
+static void fillStructuredOpRegion(
+    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
+    TypeRange outputTypes, ValueRange captures = {},
+    std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
+                                                              unsigned) {});
+
+/// Generic entry point to create both the region and the block of a LinalgOp.
 template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
-                                                      OperationState &result,
-                                                      TypeRange inputTypes,
-                                                      TypeRange outputTypes);
+static void
+createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
+                                TypeRange inputTypes, TypeRange outputTypes,
+                                ValueRange captures = {});
 
+/// Common parsing and printing used for both named structured ops created by
+/// ods-gen and by manually defined C++ ops. Does not handle regions.
 static ParseResult
 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              SmallVectorImpl<Type> &inputTypes,
                              SmallVectorImpl<Type> &outputTypes);
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+                                         NamedStructuredOpType op);
 
+/// Specific parsing and printing for named structured ops created by ods-gen.
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes);
+                             TypeRange inputTypes, TypeRange outputTypes,
+                             ArrayRef<OpAsmParser::OperandType> captures = {});
+
 static ParseResult
 parseNamedStructuredOpResults(OpAsmParser &parser,
                               SmallVectorImpl<Type> &resultTypes);
 
 template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
-                                          OperationState &result);
-
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
-                                         NamedStructuredOpType op);
+static ParseResult
+parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
+                       ArrayRef<OpAsmParser::OperandType> captures = {});
 
 static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                           TypeRange resultTypes);
@@ -83,14 +104,136 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+void CopyOp::regionBuilder(Block &block, ValueRange captures) {
+  using namespace edsc::intrinsics;
+  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
+  (linalg_yield(block.getArgument(0)));
+}
+
+void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
+                   Value output, AffineMap inputPermutation,
+                   AffineMap outputPermutation,
+                   ArrayRef<NamedAttribute> namedAttrs) {
+  result.addOperands({input, output});
+  result.addAttributes(namedAttrs);
+  if (inputPermutation)
+    result.addAttribute("inputPermutation",
+                        AffineMapAttr::get(inputPermutation));
+  if (outputPermutation)
+    result.addAttribute("outputPermutation",
+                        AffineMapAttr::get(outputPermutation));
+  result.addRegion();
+  fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
+                                 TypeRange{input.getType()},
+                                 TypeRange{output.getType()});
+}
+
+ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
+                              Type outputType) {
+  OpBuilder opBuilder(parser.getBuilder().getContext());
+  fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
+                                 TypeRange{outputType});
+  return success();
+}
+
+/// CopyOp region is elided when printing.
+void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
+
+static LogicalResult verify(CopyOp op) {
+  auto outputViewType = op.getOutputShapedType(0);
+  auto inputViewType = op.getInputShapedType(0);
+  if (inputViewType.getElementType() != outputViewType.getElementType())
+    return op.emitOpError("expects views of the same type");
+  if (inputViewType.getRank() != outputViewType.getRank())
+    return op.emitOpError("expects views of the same rank");
+  auto rank = op.getNumParallelLoops();
+  auto inputPermutationMap = op.inputPermutation();
+  if (inputPermutationMap) {
+    if (inputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional input_permutation map of rank ")
+             << rank;
+    if (!inputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional input_permutation map to be a permutation");
+  }
+  auto outputPermutationMap = op.outputPermutation();
+  if (outputPermutationMap) {
+    if (outputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional output_permutation map of rank ")
+             << rank;
+    if (!outputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional output_permutation map to be a permutation");
+  }
+  if (rank == 0 && inputPermutationMap)
+    return op.emitOpError("expected no input permutation when rank == 0");
+  if (rank == 0 && outputPermutationMap)
+    return op.emitOpError("expected no output permutation when rank == 0");
+  return success();
+}
+
+void CopyOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), input(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
+void FillOp::regionBuilder(Block &block, ValueRange captures) {
+  using namespace edsc::intrinsics;
+  assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
+  (linalg_yield(captures));
+}
 
 void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
                    Value value) {
   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
         value);
+  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
+                                 TypeRange{output.getType()}, value);
+}
+
+ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
+                              OpAsmParser::OperandType valueRef) {
+  OpBuilder opBuilder(parser.getBuilder().getContext());
+  // Resolve `valueRef` now so the region can be built with it as a capture.
+  SmallVector<Value> value;
+  if (parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value))
+    return failure();
+  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
+                                 TypeRange{outputType}, value);
+  return success();
+}
+
+/// FillOp region is elided when printing.
+void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
+
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputShapedType(0);
+  auto fillType = op.value().getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+    return op.emitOpError(
+        "expected fill op with no result value to use memref type");
+  }
+  return success();
+}
+
+void FillOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (output().getType().isa<MemRefType>())
+    effects.emplace_back(MemoryEffects::Write::get(), output(),
+                         SideEffects::DefaultResource::get());
 }
 
 //===----------------------------------------------------------------------===//
@@ -397,7 +540,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 // InitTensorOp
 //===----------------------------------------------------------------------===//
 
-
 static LogicalResult verify(InitTensorOp op) {
   RankedTensorType resultType = op.getType();
   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@@ -1396,68 +1538,6 @@ static LogicalResult verify(linalg::YieldOp op) {
 
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
-void FillOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  if (output().getType().isa<MemRefType>())
-    effects.emplace_back(MemoryEffects::Write::get(), output(),
-                         SideEffects::DefaultResource::get());
-}
-
-static LogicalResult verify(FillOp op) {
-  auto viewType = op.getOutputShapedType(0);
-  auto fillType = op.value().getType();
-  if (viewType.getElementType() != fillType)
-    return op.emitOpError("expects fill type to match view elemental type");
-  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
-    return op.emitOpError(
-        "expected fill op with no result value to use memref type");
-  }
-  return success();
-}
-
-void CopyOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), input(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), output(),
-                       SideEffects::DefaultResource::get());
-}
-
-static LogicalResult verify(CopyOp op) {
-  auto outputViewType = op.getOutputShapedType(0);
-  auto inputViewType = op.getInputShapedType(0);
-  if (inputViewType.getElementType() != outputViewType.getElementType())
-    return op.emitOpError("expects views of the same type");
-  if (inputViewType.getRank() != outputViewType.getRank())
-    return op.emitOpError("expects views of the same rank");
-  auto rank = op.getNumParallelLoops();
-  auto inputPermutationMap = op.inputPermutation();
-  if (inputPermutationMap) {
-    if (inputPermutationMap->getNumInputs() != rank)
-      return op.emitOpError("expects optional input_permutation map of rank ")
-             << rank;
-    if (!inputPermutationMap->isPermutation())
-      return op.emitOpError(
-          "expects optional input_permutation map to be a permutation");
-  }
-  auto outputPermutationMap = op.outputPermutation();
-  if (outputPermutationMap) {
-    if (outputPermutationMap->getNumInputs() != rank)
-      return op.emitOpError("expects optional output_permutation map of rank ")
-             << rank;
-    if (!outputPermutationMap->isPermutation())
-      return op.emitOpError(
-          "expects optional output_permutation map to be a permutation");
-  }
-  if (rank == 0 && inputPermutationMap)
-    return op.emitOpError("expected no input permutation when rank == 0");
-  if (rank == 0 && outputPermutationMap)
-    return op.emitOpError("expected no output permutation when rank == 0");
-  return success();
-}
-
 template <typename LinalgPoolingOp>
 static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
                                             ArrayRef<Attribute> attrs,
@@ -1690,14 +1770,25 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
-// Auto-generated Linalg named ops.
+// Support for named Linalg ops defined in ods-gen.
 //===----------------------------------------------------------------------===//
 
+/// Generic entry point to create the block for the region of a LinalgOp.
+/// This is used by both named structured ops created by ods-gen and by manually
+/// defined C++ ops.
+/// This is used by both builders and parsers.
+/// This function creates the block in the region with arguments corresponding
+/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
+/// to be ShapedType.
 template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributesImpl(
-    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
-    TypeRange outputTypes,
-    std::function<void(unsigned, unsigned)> errorHandler) {
+static void
+fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
+                       TypeRange inputTypes, TypeRange outputTypes,
+                       ValueRange captures,
+                       std::function<void(unsigned, unsigned)> errorHandler) {
+  assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
+  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
+
   // TODO: atm all operands go through getElementTypeOrSelf,
   // reconsider when we have evidence we need to.
   SmallVector<Type, 8> argTypes;
@@ -1707,7 +1798,7 @@ static void buildNamedStructuredOpRegionAndAttributesImpl(
 
   // RAII.
   OpBuilder::InsertionGuard guard(opBuilder);
-  Block *body = opBuilder.createBlock(&region, {}, argTypes);
+  Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
   unsigned actual = body->getNumArguments();
   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
   if (expected != actual)
@@ -1715,53 +1806,30 @@ static void buildNamedStructuredOpRegionAndAttributesImpl(
 
   opBuilder.setInsertionPointToStart(body);
   mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
-  NamedStructuredOpType::regionBuilder(*body);
+  NamedStructuredOpType::regionBuilder(*body, captures);
 
   // indexing_maps is an auto-generated method.
 
   // iterator_types is an auto-generated method.
 }
 
+/// Generic entry point to create both the region and the block of a LinalgOp.
 template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
-                                               OperationState &result,
-                                               TypeRange inputTypes,
-                                               TypeRange outputTypes) {
+void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
+                                     OperationState &result,
+                                     TypeRange inputTypes,
+                                     TypeRange outputTypes,
+                                     ValueRange captures) {
   Region &region = *result.addRegion();
-  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes,
+  fillStructuredOpRegion<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputTypes, captures,
       [&](unsigned expected, unsigned actual) {
-        llvm::errs() << "region expects " << expected << " args, got "
-                     << actual;
-        assert(expected != actual && "incorrect number of arguments");
+        assert(expected == actual && "incorrect number of arguments");
       });
 }
 
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes) {
-  ParseResult res = success();
-  OpBuilder opBuilder(parser.getBuilder().getContext());
-  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes,
-      [&](unsigned expected, unsigned actual) {
-        res = parser.emitError(parser.getCurrentLocation(),
-                               llvm::formatv("region expects {0} args, got {1}",
-                                             expected, actual));
-      });
-  return res;
-}
-
-static ParseResult
-parseNamedStructuredOpResults(OpAsmParser &parser,
-                              SmallVectorImpl<Type> &resultTypes) {
-  if (succeeded(parser.parseOptionalArrow()))
-    if (parser.parseTypeList(resultTypes))
-      return failure();
-  return success();
-}
-
+/// Common parsing used for both named structured ops created by ods-gen and by
+/// manually defined C++ ops. Does not handle regions.
 static ParseResult
 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              SmallVectorImpl<Type> &inputTypes,
@@ -1802,8 +1870,56 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
 }
 
 template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
-                                          OperationState &result) {
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+                                         NamedStructuredOpType op) {
+  if (!op.inputs().empty())
+    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+  if (!op.outputs().empty())
+    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// Specific parsing and printing for named structured ops created by ods-gen.
+//===----------------------------------------------------------------------===//
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+                             TypeRange inputTypes, TypeRange outputTypes,
+                             ArrayRef<OpAsmParser::OperandType> captures) {
+  ParseResult res = success();
+  OpBuilder opBuilder(parser.getBuilder().getContext());
+  // TODO: resolve `captures` into `capturedValues` once ods-gen supports
+  // captures; until then none may be passed and the region captures nothing.
+  assert(captures.empty() && "unexpected captures for named structured ops");
+  SmallVector<Value> capturedValues;
+  fillStructuredOpRegion<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputTypes, capturedValues,
+      [&](unsigned expected, unsigned actual) {
+        res = parser.emitError(
+            parser.getCurrentLocation(),
+            llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
+                          "region expects {0} args, got {1}",
+                          expected, actual));
+      });
+  return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+                              SmallVectorImpl<Type> &resultTypes) {
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(resultTypes))
+      return failure();
+  return success();
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
+                       ArrayRef<OpAsmParser::OperandType> captures) {
+  // TODO: Enable when ods-gen supports captures.
+  assert(captures.empty() && "unexpected captures for named structured ops");
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
@@ -1817,7 +1933,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
 
   std::unique_ptr<Region> region = std::make_unique<Region>();
   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
-          parser, *region, inputTypes, outputTypes))
+          parser, *region, inputTypes, outputTypes, captures))
     return failure();
   result.addRegion(std::move(region));
 
@@ -1831,15 +1947,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
   p.printOptionalArrowTypeList(resultTypes);
 }
 
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
-                                         NamedStructuredOpType op) {
-  if (!op.inputs().empty())
-    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
-  if (!op.outputs().empty())
-    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
-}
-
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
   p << op.getOperationName();
@@ -1861,6 +1968,10 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
   return verifyGenericOp<NamedStructuredOpType>(op);
 }
 
+//===----------------------------------------------------------------------===//
+// Canonicalizers and Folders.
+//===----------------------------------------------------------------------===//
+
 namespace {
 struct EraseDeadLinalgOp : public RewritePattern {
   EraseDeadLinalgOp(PatternBenefit benefit = 1)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 0be1c55c1ea78..69de55c00cb79 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
       indexingMaps, iterators,
       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
         edsc::ScopedContext scope(bodyBuilder, loc);
-        regionBuilder(*bodyBuilder.getBlock());
+        regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
       });
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 391562b032a76..d09d3e0b5edd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -52,14 +52,6 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
   return res;
 }
 
-static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
-                                        Optional<AffineMap> permutation) {
-  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
-                                        ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs)
-                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
-}
-
 template <typename IndexedValueType, typename OpType>
 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                      ArrayRef<SmallVector<Value, 8>> indexing,
@@ -178,40 +170,6 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                              outputBuffers);
 }
 
-template <typename IndexedValueType>
-static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
-  assert(copyOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto nPar = copyOp.getNumParallelLoops();
-  assert(nPar == allIvs.size());
-  auto inputIvs =
-      permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
-  auto outputIvs =
-      permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
-  SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
-  SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
-  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
-  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
-  // an n-D loop nest; with or without permutations.
-  // clang-format off
-    nPar > 0 ? O(oivs) = I(iivs) :
-               O() = I();
-  // clang-format on
-}
-
-template <typename IndexedValueType>
-static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
-  assert(fillOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto nPar = fillOp.getNumParallelLoops();
-  assert(nPar == allIvs.size());
-  auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
-  IndexedValueType O(fillOp.getOutputBuffer(0));
-  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
-  // an n-D loop nest; with or without permutations.
-  nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
-}
-
 // Create a padded view into the given `input` tensor using the 'indices'
 // to access the tensor. `skipPadding` lists the dimensions for which no padding
 // is needed e.g. the non-spatial dimensions for convolutions.
@@ -533,8 +491,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
         llvm::TypeSwitch<Operation *>(op)
-            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
-                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
+                  IndexedGenericOp, LinalgOp>([&](auto op) {
               emitScalarImplementation<IndexedValueTy>(allIvs, op);
             })
             .Default([&](Operation *op) { assert(false && "unexpected op"); });

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 49d323aebe92c..bfd288464c685 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
         llvm::map_range(linalgOp.getShapedOperandTypes(),
                         [](ShapedType t) { return t.getElementType(); }));
     block->addArguments(elementTypes);
-    linalgOp.getRegionBuilder()(*block);
+    linalgOp.getRegionBuilder()(*block, /*captures=*/{});
   }
   Block *block = &region->front();
 
@@ -333,24 +333,26 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
 
 // Return true if the op is an element-wise linalg op.
 static bool isElementwise(Operation *op) {
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp)
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
     return false;
-  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
     return false;
   // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
-    if (!genericOp.getOutputIndexingMap(i).isIdentity())
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
       return false;
   }
   // Currently bound the input indexing map to minor identity as other
   // permutations might require adding transpose ops to convert the vector read
   // to the right shape.
-  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
-    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
+  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
+    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
       return false;
   }
-  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
 }
 
 static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
@@ -393,9 +395,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-
-  if (isa<linalg::FillOp, linalg::CopyOp>(op))
-    return success();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));
@@ -407,43 +406,12 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
     return llvm::None;
 
   edsc::ScopedContext scope(builder, op->getLoc());
-  // In the case of 0-D memrefs, return null and special case to scalar load or
-  // store later.
-  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
-    // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
-    VectorizedLinalgOp res;
-    if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
-      res.tensorResults.push_back(v);
-    return res;
-  }
-  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
-    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                      << "Rewrite linalg.copy as vector.transfer_read + "
-                         "vector.transfer_write: "
-                      << *op);
-    Value vector = buildVectorRead(builder, copyOp.input());
-    VectorizedLinalgOp res;
-    if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
-      res.tensorResults.push_back(v);
-    return res;
-  }
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Vectorize linalg op as a generic: " << *op);
     return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
   }
 
-  // TODO: as soon as Copy and FillOp. get a region builder, replace all the
-  // above by:
-  // if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
-  //   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-  //                     << "Vectorize linalg op as a generic: " << *op);
-  //   return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
-  // }
-
   return vectorizeContraction(builder, cast<LinalgOp>(op));
 }
 

diff  --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
index a66006fd0af8a..1432037f2110e 100644
--- a/mlir/test/Transforms/copy-removal.mlir
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -copy-removal -split-input-file %s
-//| FileCheck %s
+// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
 
 // All linalg copies except the linalg.copy(%1, %9) must be removed since the
 // defining operation of %1 and its DeallocOp have been defined in another block.
@@ -256,7 +255,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
     %tmp2 = math.exp %gen2_arg0 : f32
     linalg.yield %tmp2 : f32
   }
-  "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
+  linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32>
   dealloc %temp : memref<2xf32>
   // CHECK: return
   return
@@ -292,7 +291,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
     linalg.yield %tmp2 : f32
   }
   // CHECK: linalg.copy
-  "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
+  linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32>
   dealloc %temp : memref<2xf32>
   return
 }
@@ -355,7 +354,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
   }
   // CHECK-NOT: linalg.copy
   // CHECK-NOT: dealloc
-  "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+  linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32>
   dealloc %0 : memref<4xf32>
   //CHECK: return
   return

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index a16a2b85a9ec1..b197ba3da65df 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -23,7 +23,7 @@
 //  IMPL-NEXT: map2 = simplifyAffineMap(map2);
 //  IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
 //
-//       IMPL:  void Test1Op::regionBuilder(Block &block) {
+//       IMPL:  void Test1Op::regionBuilder(Block &block, ValueRange captures) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //       IMPL:  AffineMap::get(3, 3, {d2, d1}, context)
 //       IMPL:  AffineMap::get(3, 3, {d0, d1}, context)
 //
-//       IMPL:  Test2Op::regionBuilder(Block &block) {
+//       IMPL:  Test2Op::regionBuilder(Block &block, ValueRange captures) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //       IMPL:  AffineMap::get(4, 4, {d3, d2}, context)
 //       IMPL:  AffineMap::get(4, 4, {d0, d1, d2}, context)
 //
-//       IMPL:  Test3Op::regionBuilder(Block &block) {
+//       IMPL:  Test3Op::regionBuilder(Block &block, ValueRange captures) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 0934967f516c0..4f57322c8be69 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
-          buildNamedStructuredOpRegionAndAttributes<{0}>(
+          createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs));
+            TypeRange(outputs)/*, TODO: support captures*/);
         }]>,
         OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
-          buildNamedStructuredOpRegionAndAttributes<{0}>(
+          createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs));
+            TypeRange(outputs)/*, TODO: support captures*/);
         }]>,
         OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -1907,7 +1907,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         {6}
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
-      let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
+      let parser = [{{
+        return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
+      }];
       let hasFolder = 1;
       let hasCanonicalizer = 1;
 
@@ -1915,8 +1917,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         // Auto-generated.
         ArrayAttr iterator_types();
         ArrayAttr indexing_maps();
-        static void regionBuilder(Block &block);
-        static std::function<void(Block &)> getRegionBuilder() {{
+        static void regionBuilder(Block &block, ValueRange captures);
+        static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
           return regionBuilder;
         }
 
@@ -1980,11 +1982,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
-        buildNamedStructuredOpRegionAndAttributes<{0}>(
+        createAndFillStructuredOpRegion<{0}>(
           $_builder,
           $_state,
           TypeRange(inputs),
-          TypeRange(outputs));
+          TypeRange(outputs)/*, TODO: support captures*/);
         {2}
       }]>
     )FMT";
@@ -2311,7 +2313,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
   };
 
   const char *regionBuilderFmt = R"FMT(
-  void {0}::regionBuilder(Block &block) {
+  void {0}::regionBuilder(Block &block, ValueRange captures) {
     using namespace edsc;
     using namespace intrinsics;
     auto args = block.getArguments();


        


More information about the Mlir-commits mailing list