[Mlir-commits] [mlir] 538ac26 - [mlir][Linalg] Create a named batch_matmul op and pipe it through.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Apr 21 09:26:54 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-21T12:09:46-04:00
New Revision: 538ac26f25d98d682c1b31821915b14125180a93

URL: https://github.com/llvm/llvm-project/commit/538ac26f25d98d682c1b31821915b14125180a93
DIFF: https://github.com/llvm/llvm-project/commit/538ac26f25d98d682c1b31821915b14125180a93.diff

LOG: [mlir][Linalg] Create a named batch_matmul op and pipe it through.

This revision is the first in a set of improvements that aim to allow
more generalized named Linalg op generation from a mathematical
specification.

This revision creates a new named op, linalg.batch_matmul, and checks that
the parser, printer, and verifier are hooked up properly.

This opened up a few design points that will be addressed in the future:
1. A named linalg op has a static region builder instead of an
explicitly parsed region. This is not currently compatible with
assemblyFormat, so a custom parser and printer are needed.
2. The convention for structured ops and tensor return values needs to
evolve to allow tensor-land and buffer-land specifications to agree.
3. referenceIndexingMaps and referenceIterators will need to become
static to allow building attributes at parse time.
4. Error messages will be improved once we have 3. and we pretty-print
in the custom form.
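
For reference, the generated op round-trips in the custom form exercised by
the new roundtrip test, both on buffers and on mixed tensor/buffer operands
(operand names as in the test):

  linalg.batch_matmul %a3, %b3, %c3
    : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
  linalg.batch_matmul %ta3, %tb3, %c3
    : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, memref<?x?x?xf32>) -> ()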

Differential Revision: https://reviews.llvm.org/D78327

Added: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 41035ed76bba..6f25c2049272 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,11 +1,46 @@
+# Declare a function to generate ODS with mlir-linalg-ods-gen
+function(add_linalg_ods_gen tc_filename output_file)
+  set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename})
+  set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td)
+  set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc)
+  set_source_files_properties(
+    ${GEN_ODS_FILE}
+    PROPERTIES GENERATED TRUE)
+  set_source_files_properties(
+    ${GEN_CPP_FILE}
+    PROPERTIES GENERATED TRUE)
+  add_custom_command(
+    OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE}
+    COMMAND mlir-linalg-ods-gen -gen-ods-decl ${TC_SOURCE} > ${GEN_ODS_FILE}
+    COMMAND mlir-linalg-ods-gen -gen-impl ${TC_SOURCE} > ${GEN_CPP_FILE}
+    MAIN_DEPENDENCY
+    ${TC_SOURCE}
+    DEPENDS
+    mlir-linalg-ods-gen
+    VERBATIM)
+  add_custom_target(
+    MLIR${output_file}IncGen
+    DEPENDS
+    mlir-linalg-ods-gen
+    ${GEN_ODS_FILE} ${GEN_CPP_FILE})
+endfunction()
+
+add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
+# Provide a short name for all external dependencies that need to
+# include Linalg in ODS
+add_custom_target(LinalgOdsGen DEPENDS MLIRLinalgNamedStructuredOpsIncGen)
+
 add_mlir_dialect(LinalgOps linalg)
+
 add_mlir_doc(LinalgDoc -gen-op-doc LinalgOps Dialects/)
+add_dependencies(LinalgOpsDocGen LinalgOdsGen)
 
 set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
 mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
 mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)
 add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
-
+add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen)
+    
 set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td)
 mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs)

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
new file mode 100644
index 000000000000..9f9b53a22011
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -0,0 +1,4 @@
+ods_def<BatchMatmulOp>:
+def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
+  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
+}

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 61d909139f1b..e7b11df3141a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -516,7 +516,8 @@ class LinalgOperandOfRank<int rank>: Type<
     CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
   >>;
 
-class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
+class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
+    [SingleBlockImplicitTerminator<"YieldOp">]> {
   let arguments = (ins Variadic<LinalgOperand>:$views,
                    I64Attr:$args_in,
                    I64Attr:$args_out,
@@ -806,11 +807,22 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
 def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
 
 class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
-  : Op<Linalg_Dialect, mnemonic,
-       !listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
+    : LinalgStructuredBase_Op<mnemonic, props> {
   string spec = ?;
-  let assemblyFormat = "`(` operands `)` attr-dict `:` "
-    "functional-type(operands, results)";
+  // We cannot use an assemblyFormat atm because we need to hook in a custom-
+  // built implicit region from a static OpClass method.
+  // TODO: Revisit in the future if/when appropriate.
+  // let assemblyFormat = "`(` operands `)` attr-dict `:` "
+  //   "functional-type(operands, results)";
+
+  // The parser needs to specialize on the OpType so it has to be auto-generated
+  // in the linalg-ods tool.
+  let printer = [{ return ::printNamedStructuredOp(p, *this); }];
+  let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
+  let hasFolder = 1;
 }
 
+// This file is auto-generated from a tc specification.
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
+
 #endif // LINALG_STRUCTURED_OPS

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 46fb9881aba5..beac1135a0bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -64,7 +64,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "Operation::operand_range", "getInputs"
     >,
     InterfaceMethod<[{
-        Return the type of the input shape at the given index.
+        Return the `i`-th input shaped type, irrespective of buffer or tensor
+        type.
       }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
         Return the subset of input operands that are of ranked tensor type.
@@ -89,6 +90,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     InterfaceMethod<[{
         Return the type of the output buffer at the given index.
       }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+        Return the `i`-th output shaped type, irrespective of buffer or tensor
+        type.
+      }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
         Return the results that are of ranked tensor type.
       }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index b13b6d268226..1c427faff693 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
@@ -119,7 +120,8 @@ class StructuredOpTraits
       return it - getInputs().begin();
     return llvm::None;
   }
-  /// Return the `i`-th input buffer type.
+  /// Return the `i`-th input shaped type, irrespective of buffer or tensor
+  /// type.
   ShapedType getInputShapedType(unsigned i) {
     return getInput(i).getType().template cast<ShapedType>();
   }
@@ -344,6 +346,17 @@ class StructuredOpTraits
   }
 };
 
+/// This class provides the API for named Linalg StructuredOps.
+template <typename ConcreteType>
+class NamedStructuredOpTraits
+    : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
+public:
+  llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
+  llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
+  std::function<void(OpBuilder &, Location, ArrayRef<Value>)>
+  emitScalarImplementation();
+};
+
 } // namespace linalg
 } // namespace OpTrait
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
index f87938c943ef..932a213980f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,3 +1,8 @@
 set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td)
+mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
 mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen)
+
+# Including Linalg in TableGen requires a dependency on the generated files
+add_dependencies(MLIRLinalgTransformPatternsIncGen LinalgOdsGen)
+

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9f664586453a..8fa90f444f63 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
@@ -30,6 +31,20 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Forward declarations.
+template <typename NamedStructuredOpType>
+static void buildNamedStructuredOpRegion(Builder &builder,
+                                         OperationState &result,
+                                         TypeRange operandTypes,
+                                         TypeRange tensorResultTypes);
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+template <typename NamedStructuredOpType>
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result);
+template <typename NamedStructuredOpType>
+static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
+
 /// Determines whether it is possible to fold it away in the parent Linalg op:
 ///
 /// ```mlir
@@ -184,7 +199,14 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
                                 parser.getCurrentLocation(), result.operands);
 }
 
-LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+template <typename GenericOpType>
+struct BlockArgsVerifier {
+  static LogicalResult verify(GenericOpType op, Block &block);
+};
+
+template <typename GenericOpType>
+LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op,
+                                                       Block &block) {
   auto nOperands = op.getNumOperands();
   if (block.getNumArguments() != nOperands)
     return op.emitOpError("expected number of block arguments to match number "
@@ -203,7 +225,9 @@ LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
   return success();
 }
 
-LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+template <>
+LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
+                                                          Block &block) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
   auto nOperands = op.getNumOperands();
@@ -245,7 +269,7 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
   auto &region = op.region();
   if (region.getBlocks().size() != 1)
     return op.emitOpError("expected region with 1 block");
-  if (failed(verifyBlockArgs(op, region.getBlocks().front())))
+  if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
     return failure();
 
   SmallVector<AffineMap, 4> indexingMaps;
@@ -737,17 +761,18 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
                  parser.resolveOperands(opInfo, types, loc, result.operands));
 }
 
-template <typename GenericOpType>
-static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
-  // The operand number and types must match the view element types.
-  auto nOutputs = genericOp.getNumOutputs();
+// Check that the operand number and types match the element types of the
+// LinalgOp interface's shaped operands.
+static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
+  auto nOutputs = linalgOpInterface.getNumOutputs();
   if (op.getNumOperands() != nOutputs)
     return op.emitOpError("expected number of yield values (")
            << nOutputs << ") to match the number of operands of the enclosing "
-           << "linalg.generic op (" << op.getNumOperands() << ")";
+           << "LinalgOp (" << op.getNumOperands() << ")";
 
   for (unsigned i = 0; i != nOutputs; ++i) {
-    auto elementType = genericOp.getOutputShapedType(i).getElementType();
+    auto elementType =
+        linalgOpInterface.getOutputShapedType(i).getElementType();
     if (op.getOperand(i).getType() != elementType)
       return op.emitOpError("type of yield operand ")
              << (i + 1) << " (" << op.getOperand(i).getType()
@@ -763,17 +788,10 @@ static LogicalResult verify(YieldOp op) {
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
     return op.emitOpError("expected single non-empty parent region");
 
-  auto genericOp = dyn_cast<GenericOp>(parentOp);
-  if (genericOp)
-    return verifyYield(op, genericOp);
-
-  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(parentOp);
-  if (indexedGenericOp)
-    return verifyYield(op, indexedGenericOp);
+  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
+    return verifyYield(op, cast<LinalgOp>(parentOp));
 
-  return op.emitOpError("expected '")
-         << GenericOp::getOperationName() << "' or '"
-         << IndexedGenericOp::getOperationName() << "' parent op";
+  return op.emitOpError("expected parent op with LinalgOp interface");
 }
 
 /////// Operations corresponding to library calls defined with Tablegen ////////
@@ -1056,3 +1074,82 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
     return getResult();
   return {};
 }
+
+//===----------------------------------------------------------------------===//
+// Auto-generated Linalg named ops.
+//===----------------------------------------------------------------------===//
+
+template <typename NamedStructuredOpType>
+void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
+                                  TypeRange operandTypes,
+                                  TypeRange tensorResultTypes) {
+  Region &region = *result.addRegion();
+  Block *body = new Block();
+  // TODO: atm all operands go through getElementTypeOrSelf,
+  // reconsider when we have evidence we need to.
+  for (auto t : operandTypes)
+    body->addArgument(getElementTypeOrSelf(t));
+  for (auto t : tensorResultTypes)
+    body->addArgument(getElementTypeOrSelf(t));
+  region.push_back(body);
+
+  OpBuilder opBuilder(builder.getContext());
+  opBuilder.setInsertionPointToStart(&region.front());
+  mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
+  NamedStructuredOpType::regionBuilder(*body);
+}
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  p << op.getOperationName() << ' ';
+  p.printOptionalAttrDict(op.getAttrs());
+  p << ' ' << op.getOperands();
+  p << ": (" << op.getOperandTypes() << ")";
+  auto outputTensorTypes = op.getResultTypes();
+  if (!outputTensorTypes.empty())
+    p << " -> (" << outputTensorTypes << ")";
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
+
+  // Optional attributes may be added.
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operandsInfo))
+    return failure();
+
+  SmallVector<Type, 8> operandTypes;
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(operandTypes) || parser.parseRParen())
+    return failure();
+
+  // Generic ops may specify that a subset of their outputs are tensors. Such
+  // outputs are specified in the result type.
+  SmallVector<Type, 8> tensorResultTypes;
+  if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+    return failure();
+
+  if (!tensorResultTypes.empty())
+    result.addTypes(tensorResultTypes);
+
+  buildNamedStructuredOpRegion<NamedStructuredOpType>(
+      parser.getBuilder(), result, operandTypes, tensorResultTypes);
+
+  return parser.resolveOperands(operandsInfo, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+template <typename NamedStructuredOpType>
+static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
+  return verifyGenericOp<NamedStructuredOpType>(op);
+}
+
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+
+// TODO: Determine whether we can generate the folders and verifiers.
+LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index b794f54ed5f9..7f39139d69e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -121,8 +121,22 @@ getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
 }
 
 namespace {
+
+// Generic loop emitter, to be specialized on a per-op basis.
+// TODO: Hook up to named ops interface and, later, retire when all named ops
+// are auto-generated.
 template <typename IndexedValueType, typename LinalgOpType>
-class LinalgScopedEmitter {};
+class LinalgScopedEmitter {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       LinalgOpType linalgOp) {
+    assert(linalgOp.hasBufferSemantics() &&
+           "expected linalg op with buffer semantics");
+    llvm_unreachable("NYI");
+    linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(),
+                                        ScopedContext::getLocation(), allIvs);
+  }
+};
 
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, CopyOp> {

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e6414a0fbd78..585dc36dcaa8 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -48,7 +48,7 @@ func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off +
 // -----
 
 func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}}
+  // expected-error @+1 {{op expected parent op with LinalgOp interface}}
   linalg.yield %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
@@ -91,7 +91,7 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
 // -----
 
 func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (0)}}
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -114,6 +114,7 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
     iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i : i32):
+    linalg.yield %i : i32
   }: memref<i32>
 }
 
@@ -128,6 +129,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
     iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i : i32):
+    linalg.yield %i : i32
   }: memref<1xi32>
 }
 
@@ -188,7 +190,8 @@ func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>
 // -----
 
 func @generic_empty_region(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected region with 1 block}}
+  %f0 = constant 0.0: f32
+  // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
@@ -196,7 +199,23 @@ func @generic_empty_region(%arg0: memref<f32>) {
     iterator_types = []
   } %arg0, %arg0 {
     ^bb1:
+      linalg.yield %f0: f32
     ^bb2:
+      linalg.yield %f0: f32
+  }: memref<f32>, memref<f32>
+}
+
+// -----
+
+func @generic_empty_region(%arg0: memref<f32>) {
+  %f0 = constant 0.0: f32
+  // expected-error @+1 {{linalg.generic' op expected region with 1 block}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 1,
+    indexing_maps =  [ affine_map<() -> (0)> ],
+    iterator_types = []
+  } %arg0, %arg0 {
   }: memref<f32>, memref<f32>
 }
 
@@ -210,7 +229,8 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = []
   } %arg0 {
-    ^bb:
+    ^bb(%f: f32, %g: f32):
+      linalg.yield %f: f32
   }: memref<f32>
 }
 
@@ -225,6 +245,7 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
     iterator_types = []
   } %arg0 {
     ^bb(%i: i1):
+    linalg.yield %i : i1
   }: memref<f32>
 }
 
@@ -239,6 +260,7 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
     iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%f: f32):
+      linalg.yield %f : f32
   }: memref<f32>
 }
 
@@ -253,6 +275,7 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
     iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i: f64, %f: f32):
+    linalg.yield %f: f32
   }: memref<f32>
 }
 
@@ -267,6 +290,7 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
     iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i: index, %f: i1):
+    linalg.yield %i: index
   }: memref<f32>
 }
 
@@ -304,7 +328,7 @@ func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
 // -----
 
 func @indexed_generic_result_count(%arg0: memref<?xf32>) {
-  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (2)}}
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -349,6 +373,38 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 
 // -----
 
+func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+  // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
+  %0 = linalg.generic {
+    args_in = 0,
+    args_out = 1,
+    indexing_maps = [ affine_map<(i) -> (i)> ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: f32):
+      linalg.yield %i: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>> -> f32
+}
+
+// -----
+
+func @generic(%arg0: memref<?x?xi4>) {
+  // expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}}
+  // expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}}
+  linalg.generic  {
+    args_in = 0,
+    args_out = 1,
+    indexing_maps = [ affine_map<(i) -> (i)> ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%0: i4) :
+      %1 = std.addf %0, %0: i4
+  } : memref<?x?xi4>
+  return
+}
+
+// -----
+
 func @generic_result_0_element_type(%arg0: memref<?xf32>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
@@ -420,3 +476,11 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
     memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
   return
 }
+
+// -----
+
+func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
+  // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref<?x?xf32>'}}
+  linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 89b910e7b04a..249c52af64b1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -620,3 +620,16 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
 //  CHECK-SAME:     memref<?x?x?xf32, #[[strided3D]]> into memref<?x?xf32, #[[strided2D]]>
 //       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
 //  CHECK-SAME:     memref<?x?xf32, #[[strided2D]]> into memref<?x?x?xf32, #[[strided3D]]>
+
+
+// TODO: Return tensors need a semantics convention update.
+func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
+                %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) {
+  linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  linalg.batch_matmul %ta3, %tb3, %c3 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
+// CHECK-LABEL: func @named_ops
+//       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.batch_matmul
+

diff  --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
index f06854289abb..0ac97b9291d3 100644
--- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
+++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
@@ -1,6 +1,8 @@
 set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td)
 mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
+# Including Linalg in TableGen requires a dependency on the generated files
+add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)
 
 set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
 mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
@@ -9,3 +11,5 @@ add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
 set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
 mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
+# Including Linalg in TableGen requires a dependency on the generated files
+add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 680d3ee28f80..2e7166d60a6b 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -1,75 +1,77 @@
 // RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
 // RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
 
-// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \
-// RUN: | mlir-tblgen -gen-op-decls -I %S/../../include
-
-// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [
-//  ODS-NEXT:   NInputs<2>,
-//  ODS-NEXT:   NOutputs<1>,
-//  ODS-NEXT:   NamedStructuredOpTraits]>
+// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
+//  ODS-NEXT:   NInputs<2>
+//  ODS-NEXT:   NOutputs<1>
+//  ODS-NEXT:   NamedStructuredOpTraits
+//  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  matvec::referenceIterators() {
+// IMPL-LABEL:  Test1Op::referenceIterators() {
 //  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  matvec::referenceIndexingMaps() {
-//       IMPL:  AffineMap::get(2, 0, {d0, d1}),
-//  IMPL-NEXT:  AffineMap::get(2, 0, {d1}),
-//  IMPL-NEXT:  AffineMap::get(2, 0, {d0}) };
+//       IMPL:  Test1Op::referenceIndexingMaps() {
+//       IMPL:  AffineMap::get(2, 0, {d0, d1}, context),
+//  IMPL-NEXT:  AffineMap::get(2, 0, {d1}, context),
+//  IMPL-NEXT:  AffineMap::get(2, 0, {d0}, context) };
 //
-//       IMPL:  matvec::regionBuilder(ArrayRef<BlockArgument> args) {
+//       IMPL:  Test1Op::regionBuilder(Block &block) {
 //       IMPL:  ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 //       IMPL:  (linalg_yield(ValueRange{ [[e]] }));
 //
-def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
+ods_def<Test1Op> :
+def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
   C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
 }
 
-// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [
-//  ODS-NEXT:   NInputs<2>,
-//  ODS-NEXT:   NOutputs<1>,
-//  ODS-NEXT:   NamedStructuredOpTraits]>
+// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
+//  ODS-NEXT:   NInputs<2>
+//  ODS-NEXT:   NOutputs<1>
+//  ODS-NEXT:   NamedStructuredOpTraits
+//  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  matmul::referenceIterators() {
+// IMPL-LABEL:  Test2Op::referenceIterators() {
 //  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  matmul::referenceIndexingMaps() {
-//       IMPL:  AffineMap::get(3, 0, {d0, d2}),
-//  IMPL-NEXT:  AffineMap::get(3, 0, {d2, d1}),
-//  IMPL-NEXT:  AffineMap::get(3, 0, {d0, d1}) };
+//       IMPL:  Test2Op::referenceIndexingMaps() {
+//       IMPL:  AffineMap::get(3, 0, {d0, d2}, context),
+//  IMPL-NEXT:  AffineMap::get(3, 0, {d2, d1}, context),
+//  IMPL-NEXT:  AffineMap::get(3, 0, {d0, d1}, context) };
 //
-//       IMPL:  matmul::regionBuilder(ArrayRef<BlockArgument> args) {
+//       IMPL:  Test2Op::regionBuilder(Block &block) {
 //       IMPL:  ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 //       IMPL:  (linalg_yield(ValueRange{ [[e]] }));
 //
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
+ods_def<Test2Op> :
+def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
 }
 
-// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
-//  ODS-NEXT:   NInputs<2>,
-//  ODS-NEXT:   NOutputs<1>,
-//  ODS-NEXT:   NamedStructuredOpTraits]>
+// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
+//  ODS-NEXT:   NInputs<2>
+//  ODS-NEXT:   NOutputs<1>
+//  ODS-NEXT:   NamedStructuredOpTraits
+//  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  batchmatmul::referenceIterators() {
+// IMPL-LABEL:  Test3Op::referenceIterators() {
 //  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  batchmatmul::referenceIndexingMaps() {
-//       IMPL:  AffineMap::get(4, 0, {d0, d1, d3}),
-//  IMPL-NEXT:  AffineMap::get(4, 0, {d3, d2}),
-//  IMPL-NEXT:  AffineMap::get(4, 0, {d0, d1, d2}) };
+//       IMPL:  Test3Op::referenceIndexingMaps() {
+//       IMPL:  AffineMap::get(4, 0, {d0, d1, d3}, context),
+//  IMPL-NEXT:  AffineMap::get(4, 0, {d3, d2}, context),
+//  IMPL-NEXT:  AffineMap::get(4, 0, {d0, d1, d2}, context) };
 //
-//       IMPL:  batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
+//       IMPL:  Test3Op::regionBuilder(Block &block) {
 //       IMPL:  ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 //       IMPL:  (linalg_yield(ValueRange{ [[e]] }));
 //
-//       TBLGEN: batchmatmulOp
-def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+ods_def<Test3Op> :
+def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
 }

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 93807afa2940..1132806da175 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -90,6 +90,7 @@ class Token {
     // Keywords.
     kw_def,
     FIRST_KEYWORD = kw_def,
+    kw_ods_def,
     kw_floordiv,
     kw_ceildiv,
     kw_mod,
@@ -289,6 +290,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
   StringRef str(tokStart, curPtr - tokStart);
   Token::Kind kind = llvm::StringSwitch<Token::Kind>(str)
                          .Case("def", Token::Kind::kw_def)
+                         .Case("ods_def", Token::Kind::kw_ods_def)
                          .Case("floordiv", Token::Kind::kw_floordiv)
                          .Case("ceildiv", Token::Kind::kw_ceildiv)
                          .Case("mod", Token::Kind::kw_mod)
@@ -896,7 +898,8 @@ struct TensorExpr : public Expression {
   TensorExpr(StringRef name,
              SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
              ArrayRef<unsigned> reductionDims)
-      : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)),
+      : Expression(Kind::TensorExpr), operationName(name),
+        expressions(std::move(exprs)),
         reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
 
   static bool classof(const Expression *e) {
@@ -904,7 +907,7 @@ struct TensorExpr : public Expression {
   }
 
   bool operator==(const TensorExpr &other) const {
-    if (opId != other.opId)
+    if (operationName != other.operationName)
       return false;
     if (expressions.size() != other.expressions.size())
       return false;
@@ -922,7 +925,7 @@ struct TensorExpr : public Expression {
   template <typename Lambda, bool PreOrder>
   void visit(Lambda callback) const;
 
-  StringRef opId;
+  StringRef operationName;
   SmallVector<std::unique_ptr<Expression>, 4> expressions;
   SetVector<unsigned> reductionDimensions;
 };
@@ -988,22 +991,22 @@ class TCParser {
   /// When `gen-impl` is used, this prints the C++ implementation for the extra
   /// methods defined in ODS (referenceIterators, referenceIndexingMaps and
   /// regionBuilder).
-  LogicalResult parseAndEmitTCDef(llvm::raw_ostream &os);
+  LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
 
   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
                 StringRef linalgOpName);
 
   /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
-  void printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+  void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
                                ComprehensionParsingState &state);
 
   /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
-  void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+  void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state);
 
   /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
-  void printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+  void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                           ComprehensionParsingState &state);
 
 private:
@@ -1346,7 +1349,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
   return success();
 }
 
-/// Parse and print the information for a TC def.
+/// Parse and print the information for an ODS def.
 ///
 ///   tensor-def-list ::= tensor-def (`,` tensor-def )*
 ///
@@ -1355,16 +1358,29 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
 ///   tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
 ///     `{` comprehension-list `}`
 ///
+///   ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
+///
 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
 /// contain only expressions involving symbols and constants), but can
 /// otherwise contain arbitrary affine expressions.
-LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) {
+LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
+  if (failed(parser.parseToken(Token::Kind::kw_ods_def,
+                               "expected 'ods_def' to define a TC ODS")) ||
+      failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
+    return failure();
+  StringRef cppOpName = parser.curToken.getSpelling();
+  LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
+
+  if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
+      failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
+      failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
+    return failure();
   if (failed(parser.parseToken(Token::Kind::kw_def,
                                "expected 'def' to define a TC")))
     return failure();
 
   StringRef tcName = parser.curToken.getSpelling();
-  LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
     return failure();
@@ -1404,7 +1420,7 @@ LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) {
   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
   while (parser.curToken.isNot(Token::Kind::r_brace)) {
     perComprehensionStates.push_back(ComprehensionParsingState());
-    if (failed(parseOneComprehension(tcName, tcName,
+    if (failed(parseOneComprehension(cppOpName, tcName,
                                      perComprehensionStates.back())))
       return failure();
   };
@@ -1418,16 +1434,16 @@ LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) {
     return failure();
   }
   if (genODSDecl) {
-    printODS(os, tcName, tcName);
+    printODS(os, cppOpName, tcName);
     os << "\n";
   }
   if (genODSImpl) {
     auto &state = perComprehensionStates.back();
     std::string extraMethods;
     llvm::raw_string_ostream ss(extraMethods);
-    printReferenceIterators(ss, tcName, state);
-    printReferenceIndexingMaps(ss, tcName, state);
-    printRegionBuilder(ss, tcName, state);
+    printReferenceIterators(ss, cppOpName, state);
+    printReferenceIndexingMaps(ss, cppOpName, state);
+    printRegionBuilder(ss, cppOpName, state);
     ss.flush();
     os << extraMethods << "\n";
   }
@@ -1442,18 +1458,32 @@ LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) {
 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
                         StringRef linalgOpName) {
-  const char *header = R"FMT(  def {0}Op : LinalgNamedStructured_Op<"{1}", [
+  const char *header = R"FMT(  def {0} : LinalgNamedStructured_Op<"{1}", [
     NInputs<{2}>,
     NOutputs<{3}>,
-    NamedStructuredOpTraits]> {
+    NamedStructuredOpTraits,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
       let arguments = (ins Variadic<LinalgOperand>:$views);
       let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+      let regions = (region SizedRegion<1>:$region);
+      let builders = [OpBuilder<
+        "Builder *b, OperationState &result, TypeRange outputTypes, "
+        # "ValueRange views",
+        [{{
+          result.addOperands(views);
+          result.addTypes(outputTypes);
+          buildNamedStructuredOpRegion<{0}>(
+            *b, result, TypeRange(views), outputTypes);
+        }]>
+      ];
+      let parser = [{
+        return ::parseNamedStructuredOp<{0}>(parser, result);
+      }];
       let extraClassDeclaration = [{{
         llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
         llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
-        void regionBuilder(ArrayRef<BlockArgument> args);
+        static void regionBuilder(Block &block);
       }];
-      let hasFolder = 1;
   })FMT";
 
   unsigned nInputs = 0, nOutputs = 0;
@@ -1468,7 +1498,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
 }
 
 /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
-void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printReferenceIterators(llvm::raw_ostream &os,
+                                       StringRef cppOpName,
                                        ComprehensionParsingState &state) {
   const char *referenceReferenceIteratorsFmt =
       R"FMT(
@@ -1498,11 +1529,12 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
       });
   ss.flush();
 
-  os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr);
+  os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
-void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
+                                          StringRef cppOpName,
                                           ComprehensionParsingState &state) {
   const char *referenceIndexingMapsFmt =
       R"FMT(
@@ -1527,7 +1559,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
     orderedUses[it.second] = it.first;
   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
     assert(u.indexingMap);
-    const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})";
+    const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
     if (u.indexingMap.isEmpty()) {
       mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
       return;
@@ -1544,11 +1576,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
   });
   mapsStringStream.flush();
 
-  os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr);
+  os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
-void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state) {
   unsigned count = state.orderedTensorArgs.size();
   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
@@ -1570,15 +1602,17 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
                             });
       subExprsStringStream.flush();
       const char *tensorExprFmt = "\n    ValueHandle _{0} = {1}({2});";
-      os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs);
+      os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
+                          subExprs);
       subExprsMap[pTensorExpr] = count;
     }
   };
 
   const char *regionBuilderFmt = R"FMT(
-  void {0}::regionBuilder(ArrayRef<BlockArgument> args) {
+  void {0}::regionBuilder(Block &block) {
     using namespace edsc;
     using namespace intrinsics;
+    auto args = block.getArguments();
     ValueHandle {1};
     {2}
     (linalg_yield(ValueRange{ {3} }));
@@ -1612,8 +1646,8 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
   expressionStringStream.flush();
   yieldStringStream.flush();
 
-  os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr,
-                      yieldStr);
+  os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
+                      expressionsStr, yieldStr);
 }
 
 /// Iterate over each Tensor Comprehension def.
@@ -1621,7 +1655,7 @@ LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
                                                   Parser &parser) {
   while (parser.curToken.getKind() != Token::Kind::eof) {
     TCParser tcParser(parser);
-    if (failed(tcParser.parseAndEmitTCDef(os)))
+    if (failed(tcParser.parseAndEmitODSDef(os)))
       return failure();
   }
   return success();


        

