[Mlir-commits] [mlir] 93fd30b - [mlir][Linalg] Evolve named ops to use assembly form and support linalg on tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 18 03:17:37 PDT 2020


Author: Nicolas Vasilache
Date: 2020-09-18T06:14:30-04:00
New Revision: 93fd30bac3345fea4f5beba3241f1ef4f2f5f419

URL: https://github.com/llvm/llvm-project/commit/93fd30bac3345fea4f5beba3241f1ef4f2f5f419
DIFF: https://github.com/llvm/llvm-project/commit/93fd30bac3345fea4f5beba3241f1ef4f2f5f419.diff

LOG: [mlir][Linalg] Evolve named ops to use assembly form and support linalg on tensors.

This revision allows representing a reduction at the level of linalg on tensors for named ops. When a structured op has a reduction and returns tensor(s), new conventions are added and documented.

As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:

```
  linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
               outs(%c : memref<?x?xf32>)
```

, whereas the syntax for a `linalg.matmul` returning a new tensor is:

```
  %d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
                    init(%c : tensor<?x?xf32>)
                      -> tensor<?x?xf32>
```

Other parts of linalg will be extended accordingly to allow mixed buffer/tensor semantics in the presence of reductions.
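
The conventions described above can be sketched as a small standalone C++ model. This is only an illustration under stated assumptions: `Kind`, `NamedOpModel`, `checkConventions` and `checkExamples` are hypothetical names that are not part of MLIR, and the check for matching `init` and result types is left out. The rules mirror the documented conventions, not the exact verifier code.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR types; this is not the MLIR API.
enum class Kind { Tensor, MemRef };

struct NamedOpModel {
  std::vector<Kind> ins;     // `ins(...)`: read-only tensors or buffers.
  std::vector<Kind> outs;    // `outs(...)`: read-write output buffers.
  std::vector<Kind> init;    // `init(...)`: initialization tensors.
  unsigned numResults = 0;   // Number of tensors returned by the op.
  bool hasReduction = false; // True if any loop is a reduction loop.
};

// Returns an error message, or std::nullopt when the conventions hold.
std::optional<std::string> checkConventions(const NamedOpModel &op) {
  for (Kind k : op.outs)
    if (k != Kind::MemRef)
      return "`outs` operands must be buffers";
  for (Kind k : op.init)
    if (k != Kind::Tensor)
      return "`init` operands must be tensors";

  // No result or no reduction: `init` must be empty and nothing else applies.
  if (!op.hasReduction || op.numResults == 0) {
    if (!op.init.empty())
      return "expected empty `init` when op has no results or no reduction";
    return std::nullopt;
  }

  // Reduction that returns a tensor: one result, one tied `init`, no `outs`.
  if (op.numResults != 1)
    return "expected single tensor result when reduction present";
  if (op.init.size() != op.numResults)
    return "expected #init tensors to match #results when reduction present";
  if (!op.outs.empty())
    return "expected empty `outs` when reduction returns a tensor";
  return std::nullopt;
}

// Usage: the two `linalg.matmul` forms from the log message.
void checkExamples() {
  NamedOpModel intoBuffer{
      {Kind::MemRef, Kind::Tensor}, {Kind::MemRef}, {}, 0, true};
  NamedOpModel newTensor{
      {Kind::Tensor, Kind::MemRef}, {}, {Kind::Tensor}, 1, true};
  assert(!checkConventions(intoBuffer).has_value());
  assert(!checkConventions(newTensor).has_value());
}
```

In this model, a matmul that returns a tensor but passes a memref in `init` is rejected, which is why the `init` operand in the tensor form must be a tensor.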

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
    mlir/include/mlir/IR/OpBase.td
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
    mlir/test/Dialect/Linalg/affine.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
    mlir/test/Dialect/Linalg/fusion-2-level.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/promote.mlir
    mlir/test/Dialect/Linalg/promotion_options.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/standard.mlir
    mlir/test/Dialect/Linalg/tile-and-distribute.mlir
    mlir/test/Dialect/Linalg/tile.mlir
    mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/IR/slice.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 7ae1b73f48a7..140197b16815 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -40,7 +40,8 @@ generic properties that enable [key transformations](#key_transformations),
 including lowering to scalar load/store and other operations or to external
 library calls and intrinsics.
 
-These ops can have ***either tensor or buffer operands***.
+These ops can have ***either tensor or buffer operands***, subject to
+[conventions and limitations](#tensors_and_buffers).
 
 ### Payload-Carrying Ops<a name="payload_ops"></a>
 Linalg defines two payload carrying operations that implement the [structured ops](
@@ -463,6 +464,76 @@ because of empirical evidence building and working on multiple high-level
 compilers. As we lay those down and engage more with the community, we expect
 multiple rounds of discussions and design changes to the original architecture.
 
+### Tensors and Buffers: Conventions and Limitations <a name="tensors_and_buffers"></a>
+
+Tensors are immutable SSA values, buffers are mutable regions of memory subject
+to side-effects and aliasing. As a consequence, output buffers are passed as
+operands whereas output tensors are new SSA values corresponding to op results.
+Inputs can be arbitrary tensors or buffers and are always passed as operands.
+
+The following convention is currently in-flight and is in the process of
+replacing other existing conventions. The following convention currently applies
+to "named" structured ops which are auto-generated by the linalg-ods tool.
+
+The convention adopted is as follows:
+
+1.  A first block of `ins` op operands hold read-only inputs of ShapedType.
+2.  An optional second block of `outs` op operands hold read-write output
+    buffers of MemRefType.
+3.  An optional third block of `init` operands hold initialization tensors of
+    RankedTensorType. Such tensors can appear when the op performs a reduction
+    and returns a tensor.
+
+Structured ops with fully parallel semantics, have empty `init`. They may either
+write in-place into `outs` buffers or return new tensors.
+
+Structured ops with reduction semantics and output tensor(s) however have
+additional restrictions:
+
+1.  They can only return a single tensor for now.
+2.  They cannot have any output buffer operand (i.e. `outs` is empty).
+3.  They have exactly one `init` tensor of the same type as the unique output
+    tensor. Such an `init` tensor does not have an explicit associate indexing
+    map. Instead the map of the result tensor is used to signify that the `init`
+    and the `result` are "tied".
+
+Points 1. and 2. keep complexity of the representation in check by allowing only
+a single result tensor, when reductions are present.
+
+Point 3. is related to the fact that SSA values cannot represent in-place
+updates. Instead, linalg adopts a similar convention that exists in e.g.
+`vector.outerproduct`: the value that is reduced into is passed as an explicit
+argument and a new result of the same shape is produced.
+
+It is expected buffer allocation will fold this last input onto the result in a
+single output buffer argument, which is why the same indexing map is required:
+the last input operand is said to be "tied" to the result.
+
+Alternative, more complex representations, would allow for:
+
+1.  Multiple results and `init` tensors in arbitrary orders, which could be
+    captured by an extra ArrayAttr of position pairs.
+2.  Relaxing the conditions on the indexing map equalities on the each pair and
+    e.g. allow implicit broadcasts of the input.
+
+These representations are deemed unnecessarily complex for now and are left for
+future discussion.
+
+As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:
+
+```
+linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
+             outs(%c : memref<?x?xf32>)
+```
+
+, whereas the syntax for a `linalg.matmul` returning a new tensor is:
+
+```
+%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+                  init(%c : tensor<?x?xf32>)
+                    -> tensor<?x?xf32>
+```
+
 ### Data Representation: Views<a name="views"></a>
 The current implementation uses the [Strided MemRef (a.k.a View)](
 https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio)
@@ -570,10 +641,10 @@ When `mlir-linalg-ods-gen -gen-ods-decl=1` is called, the following ODS is
 produced:
 
 ```
-  def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
-    NInputs<2>,
-    NOutputs<1>,
-    NamedStructuredOpTraits]> { ... }
+def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
+  NInputs<2>,
+  NOutputs<1>,
+  NamedStructuredOpTrait]> { ... }
 ```
 
 When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced:

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 1b1a2125e95d..6e4a35035110 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -699,6 +699,14 @@ The available directives are as follows:
     -   `input` must be either an operand or result [variable](#variables), the
         `operands` directive, or the `results` directive.
 
+*   `type_ref` ( input )
+
+    -   Represents a reference to the type of the given input that must have
+        already been resolved.
+    -   `input` must be either an operand or result [variable](#variables), the
+        `operands` directive, or the `results` directive.
+    -   Used to pass previously parsed types to custom directives.
+
 #### Literals
 
 A literal is either a keyword or punctuation surrounded by \`\`.
@@ -762,6 +770,10 @@ declarative parameter to `parse` method argument is detailed below:
     -   Single: `Type &`
     -   Optional: `Type &`
     -   Variadic: `SmallVectorImpl<Type> &`
+*   TypeRef Directives
+    -   Single: `Type`
+    -   Optional: `Type`
+    -   Variadic: `const SmallVectorImpl<Type> &`
 
 When a variable is optional, the value should only be specified if the variable
 is present. Otherwise, the value should remain `None` or null.
@@ -788,6 +800,10 @@ declarative parameter to `print` method argument is detailed below:
     -   Single: `Type`
     -   Optional: `Type`
     -   Variadic: `TypeRange`
+*   TypeRef Directives
+    -   Single: `Type`
+    -   Optional: `Type`
+    -   Variadic: `TypeRange`
 
 When a variable is optional, the provided value may be null.
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 09fc11bc4917..a35964c5eab4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -85,6 +85,11 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
+/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
+/// Assumes `op` is a LinalgOp.
+void getDimsOfType(Operation *op, StringRef iteratorTypeName,
+                   SmallVectorImpl<AffineExpr> &res);
+
 } // namespace linalg
 } // namespace mlir
 
@@ -96,5 +101,4 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
 
-
 #endif // MLIR_DIALECT_LINALG_LINALGOPS_H_

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 41beab059008..a6c0d16a9ee2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -32,6 +32,7 @@ class NOutputs<int args_out> :
   NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
 
 def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
+def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
 
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on linalg::View as their
@@ -798,24 +799,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
 
-class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
-    : LinalgStructuredBase_Op<mnemonic, props> {
-  string spec = ?;
-  // We cannot use an assemblyFormat atm because we need to hook in a custom-
-  // built implicit region from a static OpClass method.
-  // TODO: Revisit in the future if/when appropriate.
-  // let assemblyFormat = "`(` operands `)` attr-dict `:` "
-  //   "functional-type(operands, results)";
-
-  // The parser needs to specialize on the OpType so it has to be auto-generated
-  // in the linalg-ods tool.
-  let printer = [{ return ::printNamedStructuredOp(p, *this); }];
-  let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
-}
-
-// This file is auto-generated from a tc specification.
+// This file is auto-generated from a TC def specification.
 include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
 
 #endif // LINALG_STRUCTURED_OPS

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 0e8216cc4268..1e0e85f82c7f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -25,7 +25,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     //===------------------------------------------------------------------===//
     InterfaceMethod<
       /*desc=*/[{
-        Return the number of parallel loops within the current operation.
+        Return the number of parallel loops.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumParallelLoops",
@@ -38,7 +38,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the number of reduction loops within the current operation.
+        Return the dims that are parallel loops.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getParallelDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType($_op, getParallelIteratorTypeName(), res);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of reduction loops.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumReductionLoops",
@@ -51,7 +63,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the number of window loops within the current operation.
+        Return the dims that are reduction loops.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getReductionDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType($_op, getReductionIteratorTypeName(), res);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of window loops.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumWindowLoops",
@@ -62,6 +86,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
                                $_op.iterator_types());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that are window loops.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getWindowDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the total number of loops within the current operation.
@@ -99,14 +135,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     //   linalg.indexed_generic ops).
     InterfaceMethod<
       /*desc=*/[{
-        Return the number of inputs from the current operation.
+        Return the number of inputs.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumInputs"
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the number of outputs from the current operation.
+        Return the number of outputs.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumOutputs"
@@ -160,7 +196,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the input operands from the current operation.
+        Return the input operands.
       }],
       /*retTy=*/"Operation::operand_range",
       /*methodName=*/"getInputs",
@@ -187,7 +223,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return res;
       }]
     >,
-
     //===------------------------------------------------------------------===//
     // Output arguments handling.
     //===------------------------------------------------------------------===//
@@ -267,7 +302,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }]>,
     InterfaceMethod<
       /*desc=*/[{
-        Return the output buffers (operands) from the current operation.
+        Return the output buffers (operands).
       }],
       /*retTy=*/"Operation::operand_range",
       /*methodName=*/"getOutputBuffers",
@@ -354,7 +389,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           return getInputShapedType(i);
         if (i < getNumInputsAndOutputBuffers())
           return getOutputBufferType(i - $_op.getNumInputs());
-        return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+        return this->getOperation()->getResult(
+          i - getNumInputsAndOutputBuffers()).
+          getType().template cast<ShapedType>();
       }]>,
     InterfaceMethod<
       /*desc=*/[{
@@ -408,11 +445,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::to_vector<4>(
-            llvm::map_range($_op.indexing_maps(),
-                            [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
+        return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
       }]
     >,
     InterfaceMethod<
@@ -425,10 +458,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < getNumInputsAndOutputs());
-        return $_op.indexing_maps()
-            .getValue()[i]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i];
       }]
     >,
     InterfaceMethod<
@@ -441,10 +471,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < $_op.getNumInputs());
-        return $_op.indexing_maps()
-            .getValue()[i]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i];
       }]
     >,
     InterfaceMethod<
@@ -457,10 +484,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < $_op.getNumOutputs());
-        return $_op.indexing_maps()
-            .getValue()[i + $_op.getNumInputs()]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i + $_op.getNumInputs()];
       }]
     >,
     InterfaceMethod<

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index c4790ca617f1..ae56dd66f57a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -71,6 +71,80 @@ class StructuredOpTraits
   }
 };
 
+/// This class provides a verifier for structured ops that are known to operate
+/// on buffers or tensors and that support `ins`, `outs` and `init` arguments.
+/// This trait must be used in conjunction with an op definition or a trait that
+/// provides the methods `getNumInputs` and `getNumOutputs`.
+///
+/// Use as a trait as follows:
+///
+///   class MatmulOp : public Op<MatmulOp, OpTrait::NamedStructuredOpTrait> {
+///
+template <typename ConcreteType>
+class NamedStructuredOpTrait
+    : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTrait> {
+public:
+  unsigned getNumInputs() {
+    return cast<ConcreteType>(this->getOperation()).inputs().size();
+  }
+  unsigned getNumOutputs() {
+    ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
+    return concreteOp.output_buffers().size() +
+           concreteOp.output_tensors().size();
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    ConcreteType concreteOp = cast<ConcreteType>(op);
+    unsigned nInputAndBufferOperands =
+        concreteOp.getNumInputsAndOutputBuffers();
+    if (failed(
+            OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
+      return failure();
+
+    SmallVector<AffineExpr, 4> redDims;
+    concreteOp.getReductionDims(redDims);
+    // If no result and no reduction, only check there is no init tensor and we
+    // are done.
+    if (redDims.empty() || op->getNumResults() == 0) {
+      if (!concreteOp.init_tensors().empty())
+        return op->emitError("expected empty `init` when op has no "
+                             "results or no reduction dims");
+      return success();
+    }
+
+    // Only a single tensor result supported atm.
+    if (op->getNumResults() != 1)
+      return op->emitError(
+          "expected single tensor result when reduction present");
+
+    if (concreteOp.init_tensors().size() != op->getNumResults())
+      return op->emitError(
+          "expected #init tensors to match #results when reduction present");
+
+    for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx)
+      if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx])
+        return op->emitError("expected init tensor #")
+               << idx << " of the same type as result #" << idx;
+
+    // Output tensor indexing map may not depend on reduction index.
+    // TODO: this is not yet tested. Add a test when linalg.generic switches to
+    // this representation.
+    for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) {
+      AffineMap outputMap = concreteOp.getOutputIndexingMap(idx);
+      for (auto expr : outputMap.getResults()) {
+        for (auto dim : redDims) {
+          unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+          if (expr.isFunctionOfDim(pos))
+            return op->emitError(
+                       "unexpected single tensor output indexing map ")
+                   << "is function of reduction dim @" << pos;
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace linalg
 } // namespace OpTrait
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index b038819bca3d..c9103a2b8b63 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -15,8 +15,6 @@
 
 include "mlir/IR/OpBase.td"
 
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
 //===----------------------------------------------------------------------===//
 // Shape Inference dialect definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ec0e229ae627..d314393caae2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -571,6 +571,10 @@ class VectorOfLengthAndType<list<int> allowedLengths,
 
 def AnyVector : VectorOf<[AnyType]>;
 
+// Shaped types.
+
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+
 // Tensor types.
 
 // Any tensor type whose element type is from the given `allowedTypes` list

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
index 8f3c6df79f90..97ea95c8bcd1 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_1d_filled_f32(%s1 : index, %f : f32) -> memref<?xf32> {
 }
 
 func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
-  linalg.conv_1d %arg0, %arg1, %arg2 : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+  linalg.conv_1d ins (%arg0, %arg1: memref<?xf32>, memref<?xf32>)
+                outs (%arg2: memref<?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
index 46634a7e5921..dcfcc9b62bbc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
 }
 
 func @conv_1d_ncw(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
-  linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  linalg.conv_1d_ncw ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                    outs (%arg2: memref<?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
index a6aeb30fc153..2e79b46801bc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
 }
 
 func @conv_1d_nwc(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
-  linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  linalg.conv_1d_nwc ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                    outs (%arg2: memref<?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
index 819d95ef5da0..e271b0a009b6 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_2d_filled_f32(%s1 : index, %s2 : index, %f : f32) -> memref<?x?xf32>
 }
 
 func @conv_2d(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
-  linalg.conv_2d %arg0, %arg1, %arg2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.conv_2d ins (%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+                outs (%arg2: memref<?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
index fb0e70861864..e27c40524fcc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
 }
 
 func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+  linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+                     outs (%arg2: memref<?x?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
index 5888eec7d67a..b5b4a5c82c09 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
 }
 
 func @conv_2d_nhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+  linalg.conv_2d_nhwc ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+                     outs (%arg2: memref<?x?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
index f0ca37f86fcd..12ea94696660 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
 }
 
 func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
-  linalg.conv_3d %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  linalg.conv_3d ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                outs (%arg2: memref<?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
index a56a260b9cd8..e36abc83b700 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s
 }
 
 func @conv_3d_ncdhw(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
-  linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  linalg.conv_3d_ncdhw ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+                      outs (%arg2: memref<?x?x?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
index 37fc6453e5dd..b302b3e0d8bd 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s
 }
 
 func @conv_3d_ndhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
-  linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  linalg.conv_3d_ndhwc ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+                      outs (%arg2: memref<?x?x?x?x?xf32>)
   return
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index efe2e45f78ea..7b9ba74f5492 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Support/LLVM.h"
 
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -35,13 +36,29 @@ using namespace mlir::linalg;
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(
-    Builder &builder, OperationState &result, TypeRange operandTypes,
-    TypeRange tensorResultTypes);
+    OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
+    TypeRange outputBufferTypes, TypeRange initTensorTypes,
+    TypeRange resultTypes);
+
 template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+                             TypeRange inputTypes, TypeRange outputBufferTypes,
+                             TypeRange initTensorTypes, TypeRange resultTypes);
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+                              SmallVectorImpl<Type> &resultTypes);
+
 template <typename NamedStructuredOpType>
 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result);
+
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+                                          TypeRange resultTypes);
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+
 template <typename NamedStructuredOpType>
 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 
@@ -248,11 +265,6 @@ template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
-  auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
-  if (nInputsAndOutputBuffers != llvm::size(op.views()))
-    return op.emitOpError("expected exactly ")
-           << nInputsAndOutputBuffers
-           << " inputs (tensor or buffer) and output buffer operands";
 
   auto &region = op.region();
   if (!llvm::hasSingleElement(region))
@@ -302,8 +314,27 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
   return success();
 }
 
-static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
-static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
+static LogicalResult verify(GenericOp op) {
+  // Temporarily hoisted here to avoid duplicating more code.
+  // TODO: uniformize with named structured ops.
+  auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+  if (nInputsAndOutputBuffers != llvm::size(op.views()))
+    return op.emitOpError("expected exactly ")
+           << nInputsAndOutputBuffers
+           << " inputs (tensor or buffer) and output buffer operands";
+  return verifyGenericOp(op);
+}
+
+static LogicalResult verify(IndexedGenericOp op) {
+  // Temporarily hoisted here to avoid duplicating more code.
+  // TODO: uniformize with named structured ops.
+  auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+  if (nInputsAndOutputBuffers != llvm::size(op.views()))
+    return op.emitOpError("expected exactly ")
+           << nInputsAndOutputBuffers
+           << " inputs (tensor or buffer) and output buffer operands";
+  return verifyGenericOp(op);
+}
 
 //===----------------------------------------------------------------------===//
 // ReshapeOp
@@ -1098,12 +1129,28 @@ static LogicalResult verify(PoolingSumOp op) {
 
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
 
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
 
+/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
+/// Assumes `op` is a LinalgOp.
+void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
+                                 SmallVectorImpl<AffineExpr> &res) {
+  unsigned dim = 0;
+  MLIRContext *ctx = op->getContext();
+  for (auto tn :
+       cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
+    if (tn == iteratorTypeName)
+      res.push_back(getAffineDimExpr(dim, ctx));
+    ++dim;
+  }
+}
+
 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
                                              unsigned rank,
                                              MLIRContext *context) {
@@ -1196,8 +1243,8 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 }
 
 // TODO: Consider making all this boilerplate easy to autogenerate
-// with Tablegen. This seems a desirable property in the context of OpInterfaces
-// where a Linalg "named" op **isa** LinalgOp.
+// with Tablegen. This seems a desirable property in the context of
+// OpInterfaces where a Linalg "named" op **isa** LinalgOp.
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1222,23 +1269,28 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
 //===----------------------------------------------------------------------===//
 
 template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
-                                               OperationState &result,
-                                               TypeRange operandTypes,
-                                               TypeRange tensorResultTypes) {
-  Region &region = *result.addRegion();
-  Block *body = new Block();
+static void buildNamedStructuredOpRegionAndAttributesImpl(
+    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
+    TypeRange outputBufferTypes, TypeRange initTensorTypes,
+    TypeRange resultTypes,
+    std::function<void(unsigned, unsigned)> errorHandler) {
   // TODO: atm all operands go through getElementTypeOrSelf,
   // reconsider when we have evidence we need to.
-  for (auto t : operandTypes)
-    body->addArgument(getElementTypeOrSelf(t));
-  for (auto t : tensorResultTypes)
-    body->addArgument(getElementTypeOrSelf(t));
-  region.push_back(body);
-
-  OpBuilder opBuilder(builder.getContext());
-  opBuilder.setInsertionPointToStart(&region.front());
-  mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
+  SmallVector<Type, 8> argTypes;
+  for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
+    for (auto t : containers)
+      argTypes.push_back(getElementTypeOrSelf(t));
+
+  // RAII.
+  OpBuilder::InsertionGuard guard(opBuilder);
+  Block *body = opBuilder.createBlock(&region, {}, argTypes);
+  unsigned actual = body->getNumArguments();
+  unsigned expected = NamedStructuredOpType::getNumRegionArgs();
+  if (expected != actual)
+    return errorHandler(expected, actual);
+
+  opBuilder.setInsertionPointToStart(body);
+  mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
   NamedStructuredOpType::regionBuilder(*body);
 
   // indexing_maps is an auto-generated method.
@@ -1247,59 +1299,133 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
 }
 
 template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
-  std::array<StringRef, 2> silentAttrNames{getIndexingMapsAttrName(),
-                                           getIteratorTypesAttrName()};
-  p << op.getOperationName() << ' ';
-  p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
-  p << ' ' << op.getOperands();
-  p << " : (" << op.getOperandTypes() << ")";
-  auto outputTensorTypes = op.getResultTypes();
-  if (!outputTensorTypes.empty())
-    p << " -> (" << outputTensorTypes << ")";
+void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+                                               OperationState &result,
+                                               TypeRange inputTypes,
+                                               TypeRange outputBufferTypes,
+                                               TypeRange initTensorTypes,
+                                               TypeRange resultTypes) {
+  Region &region = *result.addRegion();
+  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+      resultTypes, [&](unsigned expected, unsigned actual) {
+        llvm::errs() << "region expects " << expected << " args, got "
+                     << actual;
+        assert(expected != actual && "incorrect number of arguments");
+      });
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+                             TypeRange inputTypes, TypeRange outputBufferTypes,
+                             TypeRange initTensorTypes, TypeRange resultTypes) {
+  ParseResult res = success();
+  OpBuilder opBuilder(parser.getBuilder().getContext());
+  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+      resultTypes, [&](unsigned expected, unsigned actual) {
+        res = parser.emitError(parser.getCurrentLocation(),
+                               llvm::formatv("region expects {0} args, got {1}",
+                                             expected, actual));
+      });
+  return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+                              SmallVectorImpl<Type> &resultTypes) {
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(resultTypes))
+      return failure();
+  return success();
 }
 
 template <typename NamedStructuredOpType>
 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
-  result.getContext()->getOrLoadDialect<StandardOpsDialect>();
+  llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
+      initTensorsOperandsLoc;
+  SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
+      outputBuffersOperands, initTensorsOperands;
+  SmallVector<Type, 1> inputsTypes, outputBuffersTypes, initTensorsTypes,
+      outputTensorsTypes;
+  std::unique_ptr<Region> regionRegion = std::make_unique<Region>();
 
-  // Optional attributes may be added.
-  if (parser.parseOperandList(operandsInfo) ||
-      parser.parseOptionalAttrDict(result.attributes))
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseKeyword("ins") || parser.parseLParen())
     return failure();
 
-  SmallVector<Type, 8> operandTypes;
-  if (parser.parseColon() || parser.parseLParen() ||
-      parser.parseTypeList(operandTypes) || parser.parseRParen())
+  inputsOperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(inputsOperands) || parser.parseColon() ||
+      parser.parseTypeList(inputsTypes) || parser.parseRParen())
     return failure();
 
-  // Generic ops may specify that a subset of its outputs are tensors. Such
-  // outputs are specified in the result type.
-  SmallVector<Type, 8> tensorResultTypes;
-  if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+  if (succeeded(parser.parseOptionalKeyword("outs"))) {
+    outputBuffersOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseLParen() ||
+        parser.parseOperandList(outputBuffersOperands) || parser.parseColon() ||
+        parser.parseTypeList(outputBuffersTypes) || parser.parseRParen())
+      return failure();
+  }
+  if (succeeded(parser.parseOptionalKeyword("init"))) {
+    initTensorsOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
+        parser.parseColon() || parser.parseTypeList(initTensorsTypes) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
     return failure();
 
-  if (!tensorResultTypes.empty())
-    result.addTypes(tensorResultTypes);
+  if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
+          parser, *regionRegion, inputsTypes, outputBuffersTypes,
+          initTensorsTypes, outputTensorsTypes))
+    return failure();
 
-  // The number of parsed arguments must equal
-  // the number of expected arguments for the current operation.
-  auto parsedArgs = operandsInfo.size();
-  auto expectedArgs = NamedStructuredOpType::getNumInputs() +
-                      NamedStructuredOpType::getNumOutputs();
-  if (parsedArgs != expectedArgs)
-    return parser.emitError(parser.getNameLoc(),
-                            "expects " + std::to_string(expectedArgs) +
-                                " operands, but found " +
-                                std::to_string(parsedArgs));
+  if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
+                             result.operands) ||
+      parser.resolveOperands(outputBuffersOperands, outputBuffersTypes,
+                             outputBuffersOperandsLoc, result.operands) ||
+      parser.resolveOperands(initTensorsOperands, initTensorsTypes,
+                             initTensorsOperandsLoc, result.operands))
+    return failure();
 
-  buildNamedStructuredOpRegionAndAttributes<NamedStructuredOpType>(
-      parser.getBuilder(), result, operandTypes, tensorResultTypes);
+  result.addTypes(outputTensorsTypes);
+  result.addRegion(std::move(regionRegion));
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(
+                          {static_cast<int32_t>(inputsOperands.size()),
+                           static_cast<int32_t>(outputBuffersOperands.size()),
+                           static_cast<int32_t>(initTensorsOperands.size())}));
+  return success();
+}
 
-  return parser.resolveOperands(operandsInfo, operandTypes,
-                                parser.getCurrentLocation(), result.operands);
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+                                          TypeRange resultTypes) {
+  if (resultTypes.empty())
+    return;
+  p << "-> " << resultTypes;
+}
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  p << op.getOperationName();
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{"operand_segment_sizes"});
+  p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+  if (!op.output_buffers().empty())
+    p << " outs(" << op.output_buffers() << " : "
+      << op.output_buffers().getTypes() << ")";
+  if (!op.init_tensors().empty())
+    p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
+      << ")";
+  p << " ";
+  printNamedStructuredOpResults(p, op.output_tensors().getTypes());
+  p << " ";
+
+  // Region is elided.
 }
 
 template <typename NamedStructuredOpType>
@@ -1354,8 +1480,6 @@ CANONICALIZERS_AND_FOLDERS(FillOp)
 CANONICALIZERS_AND_FOLDERS(GenericOp)
 CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
 
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
-
 // TODO: Determine whether we can generate the folders and verifiers.
 CANONICALIZERS_AND_FOLDERS(BatchMatmulOp)
 CANONICALIZERS_AND_FOLDERS(DotOp)
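
Note on the LinalgOps.cpp changes above: the new printNamedStructuredOp emits operand groups in a fixed order: a mandatory `ins(...)`, then `outs(...)` and `init(...)` only when non-empty, then an optional `-> result-types`. The following is a minimal standalone sketch of that ordering, not the MLIR implementation. `Operand`, `formatGroup` and `formatNamedOp` are hypothetical names introduced for illustration; values are modeled as plain strings, the attribute dictionary is omitted, and whitespace is simplified relative to the real printer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value: its name and its printed type.
struct Operand {
  std::string name;
  std::string type;
};

// Formats one group as " keyword(%x, %y : typeX, typeY)".
static std::string formatGroup(const std::string &keyword,
                               const std::vector<Operand> &operands) {
  std::string names, types;
  for (std::size_t i = 0; i < operands.size(); ++i) {
    if (i != 0) {
      names += ", ";
      types += ", ";
    }
    names += operands[i].name;
    types += operands[i].type;
  }
  return " " + keyword + "(" + names + " : " + types + ")";
}

// Mirrors the group ordering used by the printer in the diff:
// ins is always printed; outs, init and the result list only when non-empty.
std::string formatNamedOp(const std::string &opName,
                          const std::vector<Operand> &inputs,
                          const std::vector<Operand> &outputBuffers,
                          const std::vector<Operand> &initTensors,
                          const std::vector<std::string> &resultTypes) {
  std::string out = opName + formatGroup("ins", inputs);
  if (!outputBuffers.empty())
    out += formatGroup("outs", outputBuffers);
  if (!initTensors.empty())
    out += formatGroup("init", initTensors);
  if (!resultTypes.empty()) {
    out += " -> ";
    for (std::size_t i = 0; i < resultTypes.size(); ++i) {
      if (i != 0)
        out += ", ";
      out += resultTypes[i];
    }
  }
  return out;
}
```

Called with two memref inputs and one memref output buffer, the sketch yields the same shape as the updated tests, e.g. `linalg.matmul ins(%A, %B : ...) outs(%C : ...)`. The parser in the diff consumes the groups in the same order and records the three group sizes in the `operand_segment_sizes` attribute.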

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index abc82f300f63..4dfc3d605570 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,6 +58,8 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//
 
 void mlir::linalg::LinalgDialect::initialize() {
+  getContext()->getOrLoadDialect("std");
+
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
@@ -67,6 +69,7 @@ void mlir::linalg::LinalgDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >();
+
   addInterfaces<LinalgInlinerInterface>();
 }
 

diff  --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
index c2e8a31eb443..eeb2ca31fd2a 100644
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -8,7 +8,8 @@
 // CHECK-DAG:  #[[$map5:.*]] = affine_map<(d0) -> (d0)>
 
 func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
-  linalg.conv_1d %arg0, %arg1, %arg2 : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+  linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+                outs(%arg2 : memref<?xf32>)
   return
 }
 

diff  --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index 13f37d844b8a..0df7db06e4c4 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -15,7 +15,8 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
   %B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
   %C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%C: memref<?x?xf32>)
   return
 }
 
@@ -102,7 +103,8 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
 // Named ops to loops.
 //----------------------------------------------------------------------------//
 func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
-  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%C : memref<?x?x?xf32>)
   return
 }
 // CHECK-LABEL: @named_batch_matmul

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 85321084cd0c..23e9f1784541 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -14,8 +14,9 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
   // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
   %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
 
-  // CHECK:  linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>)
-  linalg.matmul %3, %3, %3 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  // CHECK:  linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
+  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%3: memref<?x?xf32>)
   return %4: memref<?x?xf32>
 }
 

diff  --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
index 84c56ee3d840..72d76b3d1869 100644
--- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns
-//| FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
 
 // CHECK-LABEL: scf_for
 func @scf_for(%A : memref<i64>, %step : index) {

diff  --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 44dd268998d2..0c9b9ca0dca7 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -12,7 +12,8 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
   %0 = dim %C, %c0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+               outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
   scf.for %arg5 = %c0 to %0 step %c20 {
     scf.for %arg6 = %c0 to %2 step %c30 {
       scf.for %arg7 = %c0 to %1 step %c40 {
@@ -28,7 +29,8 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
               %14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               %16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               %17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-              linalg.matmul %14, %16, %17 : (memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
+              linalg.matmul ins(%14, %16: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                           outs(%17: memref<?x?xf32, offset: ?, strides: [?, ?]>)
             }
           }
         }

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 7fb5a7ab4e85..38e1b43be668 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -14,10 +14,9 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
   %0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, 1]>
   %1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
   %2 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, 1]>,
-     memref<?x?xf32, offset: 0, strides: [?, 1]>,
-     memref<?x?xf32, offset: 0, strides: [?, 1]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, 1]>,
+                             memref<?x?xf32, offset: 0, strides: [?, 1]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, 1]>)
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
       scf.for %arg7 = %c0 to %1 step %c4 {
@@ -30,10 +29,9 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
         %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, 1]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8: memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -61,10 +59,9 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C: memref<?x?xf32, offset: 0, strides: [?, ?]>)
   %0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -80,10 +77,9 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -113,10 +109,9 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   %0 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -132,10 +127,9 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -165,14 +159,12 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   %0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -188,10 +180,9 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -227,14 +218,12 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %0 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %C, %B, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%C, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   scf.for %arg5 = %c0 to %1 step %c2 {
     scf.for %arg6 = %c0 to %0 step %c3 {
       scf.for %arg7 = %c0 to %2 step %c4 {
@@ -247,10 +236,9 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -275,9 +263,9 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 // CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
 // CHECK-DAG:    %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
 // CHECK-DAG:    %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK:        linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]]
-// CHECK:        linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]]
-// CHECK:        linalg.matmul %[[D_IK]], %[[B_KJ]], %[[E_IJ]]
+// CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
+// CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
+// CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
 
 // -----
 
@@ -297,14 +285,12 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c3 = constant 3 : index
   %c2 = constant 2 : index
   %0 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %C, %E :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%E : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   %1 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %1 step %c2 {
@@ -322,10 +308,9 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -359,14 +344,12 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %3 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %4 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %C, %E :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%E : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
       scf.for %arg7 = %c0 to %1 step %c4 {
@@ -379,10 +362,9 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %7, %9, %10 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%7, %9 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%10 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -398,10 +380,9 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %7, %9, %10 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%7, %9 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%10 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -414,7 +395,7 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 // CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  linalg.matmul %[[A]], %[[C]], %[[E]]
+// CHECK:  linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
@@ -445,14 +426,12 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c2 = constant 2 : index
   %0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %C, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
@@ -469,10 +448,9 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -742,10 +720,9 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
   %B = alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %C = alloc(%dim, %dim)[%s0, %s1]  : memref<?x?xf32, offset: 0, strides: [?, ?]>
 
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+                             memref<?x?xf32, offset: 0, strides: [?, ?]>)
+               outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
 
   scf.for %i = %c0 to %dim step %c2 {
     scf.for %j = %c0 to %dim step %c3 {
@@ -759,10 +736,9 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
         %2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %0, %1, %2 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul ins(%0, %1 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, ?]>)
+                     outs(%2 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 3774aed7ad1f..dce5c21db252 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,13 +428,6 @@ func @generic(%arg0: memref<?x?xi4>) {
 
 // -----
 
-func @generic_result_0_element_type(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}}
-  linalg.dot %arg0, %arg0 : (memref<?xf32>, memref<?xf32>)
-}
-
-// -----
-
 func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
   // expected-error @+1 {{expects memref ranks to be greater than 2}}
   linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
@@ -511,7 +504,8 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
 
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
   // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref<?x?xf32>'}}
-  linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
+  linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
+                     outs(%c3 : memref<?x?x?xf32>)
   return
 }
 
@@ -531,3 +525,52 @@ func @generic(%arg0: tensor<?x?xi4>) {
   } : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
   return
 }
+
+// -----
+
+func @empty_init_expected(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+  // expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}}
+  linalg.matmul ins(%m, %m: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%m : memref<?x?xf32>)
+               init(%t : tensor<?x?xf32>)
+  return
+}
+
+// -----
+
+func @incorrect_region_arg_count(%m: memref<?x?xf32>) {
+  // expected-error @+3 {{region expects 3 args, got 4}}
+  %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+                       -> tensor<?x?xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+func @single_tensor_result(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+  // expected-error @+1 {{expected single tensor result when reduction present}}
+  %res:2 = linalg.matmul ins(%m : memref<?x?xf32>)
+                        init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
+                          -> tensor<?x?xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+  // expected-error @+1 {{expected #init tensors to match #results when reduction present}}
+  %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+                      init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
+                        -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+  // expected-error @+1 {{expected init tensor #0 of the same type as result #0}}
+  %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+                      init(%t : tensor<?x?xf32>)
+                        -> tensor<?xf32>
+  return
+}

diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 1e10e036ee2d..04ca27b8e175 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -39,7 +39,8 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
   %B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
   %C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%C: memref<?x?xf32>)
   return
 }
 // CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
@@ -83,7 +84,8 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
   %2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
   %3 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
   %4 = view %arg0[%c0][%N] : memref<?xi8> to memref<?xf32>
-  linalg.matvec %2, %3, %4 : (memref<?x?xf32>, memref<?xf32>, memref<?xf32>)
+  linalg.matvec ins(%2, %3: memref<?x?xf32>, memref<?xf32>)
+               outs(%4 : memref<?xf32>)
   return
 }
 // CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
@@ -123,7 +125,8 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
   %1 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
   %2 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
   %3 = view %arg0[%c0][] : memref<?xi8> to memref<f32>
-  linalg.dot %1, %2, %3 : (memref<?xf32>, memref<?xf32>, memref<f32>)
+  linalg.dot ins(%1, %2 : memref<?xf32>, memref<?xf32>)
+            outs(%3 : memref<f32>)
   return
 }
 // CHECKLOOP-LABEL: func @dot(%{{.*}}: memref<?xi8>,
@@ -154,9 +157,9 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
 
 
 func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
-  linalg.dot %arg0, %arg1, %arg2  : (memref<?xf32, offset: ?, strides: [1]>, 
-                                     memref<?xf32, offset: ?, strides: [1]>, 
-                                     memref<f32>)
+  linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
+                                memref<?xf32, offset: ?, strides: [1]>)
+            outs(%arg2:  memref<f32>)
   return
 }
 // CHECKLOOP-LABEL: func @dot_view(
@@ -880,7 +883,8 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
 // Named ops to loops.
 //----------------------------------------------------------------------------//
 func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
-  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%C : memref<?x?x?xf32>)
   return
 }
 // CHECKLOOP-LABEL: @named_batch_matmul
@@ -1288,7 +1292,8 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out :  m
 //       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
 
 func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
-  linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+  linalg.conv_1d ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+                outs(%out : memref<?xf32>)
   return
 }
 
@@ -1330,7 +1335,8 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 
 
 func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
-  linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.conv_2d ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+                outs(%out: memref<?x?xf32>)
   return
 }
 // CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1382,7 +1388,8 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 
 
 func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
-  linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  linalg.conv_3d ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                outs(%out : memref<?x?x?xf32>)
   return
 }
 

diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index ecf8e20e7ef4..7988abbeaf43 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -27,10 +27,10 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
         %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
         %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
         %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-        linalg.matmul %11, %14, %17 :
-          (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           memref<?x?xf32, offset: ?, strides: [?, 1]>)
+        linalg.matmul
+          ins(%11, %14: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                        memref<?x?xf32, offset: ?, strides: [?, 1]>)
+         outs(%17: memref<?x?xf32, offset: ?, strides: [?, 1]>)
       }
     }
   }
@@ -67,10 +67,7 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] :
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
 //
 //       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) :
 //       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
@@ -103,10 +100,10 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
         %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
         %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
         %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
-        linalg.matmul %11, %14, %17 :
-          (memref<?x?xf64, offset: ?, strides: [?, 1]>,
-           memref<?x?xf64, offset: ?, strides: [?, 1]>,
-           memref<?x?xf64, offset: ?, strides: [?, 1]>)
+        linalg.matmul
+          ins(%11, %14: memref<?x?xf64, offset: ?, strides: [?, 1]>,
+                        memref<?x?xf64, offset: ?, strides: [?, 1]>)
+         outs(%17: memref<?x?xf64, offset: ?, strides: [?, 1]>)
       }
     }
   }
@@ -140,10 +137,7 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] :
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
 //
 //       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
 //       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,

diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 270a63cf8609..0f38c904eb5a 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -2,8 +2,9 @@
 
 func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-   linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"}
-     : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+   linalg.matmul {__internal_linalg_transform__ = "START"}
+     ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+    outs(%c: memref<?x?xf32>)
    return
 }
 
@@ -26,7 +27,7 @@ func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:       linalg.copy(%[[T7]], %[[T19]])
 //      CHECK:       linalg.fill(%[[T21]], %[[C42]])
 //      CHECK:       linalg.copy(%[[T17]], %[[T21]])
-//      CHECK:       linalg.matmul %[[T19]], %[[T12]], %[[T21]]
+//      CHECK:       linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]]
 //  CHECK-NOT:       linalg.fill
 //      CHECK:       linalg.copy(%[[T21]], %[[T17]])
 //      CHECK:       dealloc %[[T18]]

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 404c978fa61b..1d58259a578e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -83,30 +83,30 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %arg1: memref<?xf32, offset: ?, strides: [1]>,
           %arg2: memref<?xf32, offset: ?, strides: [1]>,
           %arg3: memref<f32>) {
-  linalg.matmul %arg0, %arg0, %arg0 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                                       memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                                       memref<?x?xf32, offset: ?, strides: [?, 1]>)
-  linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                                       memref<?xf32, offset: ?, strides: [1]>,
-                                       memref<?xf32, offset: ?, strides: [1]>) 
-  linalg.dot %arg1, %arg2, %arg3  : (memref<?xf32, offset: ?, strides: [1]>,
-                                     memref<?xf32, offset: ?, strides: [1]>,
-                                     memref<f32>)
+  linalg.matmul ins(%arg0, %arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                                   memref<?x?xf32, offset: ?, strides: [?, 1]>)
+               outs(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matvec ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                                  memref<?xf32, offset: ?, strides: [1]>)
+               outs(%arg2: memref<?xf32, offset: ?, strides: [1]>)
+  linalg.dot ins(%arg1, %arg2: memref<?xf32, offset: ?, strides: [1]>,
+                               memref<?xf32, offset: ?, strides: [1]>)
+            outs(%arg3: memref<f32>)
   return
 }
 // CHECK-LABEL: func @ops(%
-//  CHECK-NEXT:  linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} :
-//  CHECK-SAME:    (memref<?x?xf32, #[[$strided2D]]>,
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>)
-//  CHECK-NEXT:  linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} :
-//  CHECK-SAME:    (memref<?x?xf32, #[[$strided2D]]>,
-//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>,
-//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>)
-//  CHECK-NEXT:  linalg.dot %{{.*}}, %{{.*}}, %{{.*}} :
-//  CHECK-SAME:     (memref<?xf32, #[[$strided1D]]>,
-//  CHECK-SAME:      memref<?xf32, #[[$strided1D]]>,
-//  CHECK-SAME:      memref<f32>)
+// CHECK: linalg.matmul
+// CHECK-SAME:   ins(%{{.*}}, %{{.*}} : memref<?x?xf32, #[[$strided2D]]>,
+// CHECK-SAME:                          memref<?x?xf32, #[[$strided2D]]>)
+// CHECK-SAME:  outs(%{{.*}} : memref<?x?xf32, #[[$strided2D]]>)
+// CHECK: linalg.matvec
+// CHECK-SAME:   ins(%{{.*}}, %{{.*}}: memref<?x?xf32, #[[$strided2D]]>,
+// CHECK-SAME:                         memref<?xf32, #[[$strided1D]]>)
+// CHECK-SAME:  outs(%{{.*}}: memref<?xf32, #[[$strided1D]]>)
+// CHECK: linalg.dot
+// CHECK-SAME:   ins(%{{.*}}, %{{.*}}: memref<?xf32, #[[$strided1D]]>,
+// CHECK-SAME:                         memref<?xf32, #[[$strided1D]]>)
+// CHECK-SAME:  outs(%{{.*}}: memref<f32>)
 
 // -----
 
@@ -619,17 +619,27 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x?x?xf32, #[[$strided3D]]>
 
-
-// TODO: Return tensors need a semantics convention update.
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
-                %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) {
-  linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
-  linalg.batch_matmul %ta3, %tb3, %c3 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, memref<?x?x?xf32>) -> ()
-  return
+                %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+{
+  linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                     outs(%c3: memref<?x?x?xf32>)
+  linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+                     outs(%c3: memref<?x?x?xf32>)
+  %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+                     init(%tc3: tensor<?x?x?xf32>)
+                  -> tensor<?x?x?xf32>
+  %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
+                     init(%tc3: tensor<?x?x?xf32>)
+                  -> tensor<?x?x?xf32>
+  return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
 //       CHECK:   linalg.batch_matmul
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.batch_matmul
 
 // -----
 

diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
index 60b348110c4f..638fdb885162 100644
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -13,9 +13,9 @@
 func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>,
           %arg1: memref<?xf32, offset: ?, strides: [1]>,
           %arg2: memref<f32>) {
-  linalg.dot %arg0, %arg1, %arg2  : (memref<?xf32, offset: ?, strides: [1]>,
-                                     memref<?xf32, offset: ?, strides: [1]>,
-                                     memref<f32>)
+  linalg.dot ins(%arg0, %arg1: memref<?xf32, offset: ?, strides: [1]>,
+                               memref<?xf32, offset: ?, strides: [1]>)
+             outs(%arg2: memref<f32>)
   return
 }
 // CHECK-LABEL: func @dot(

diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index 08f6d19fe6d6..6ff4be0169fb 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -2,8 +2,9 @@
 
 func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute1"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "distribute1"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c: memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -21,14 +22,15 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:   %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:   %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:   %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
-//      CHECK:   linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:   linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 
 func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute2"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul  {__internal_linalg_transform__ = "distribute2"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c:memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -52,14 +54,15 @@ func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-//      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 
 func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute3"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "distribute3"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c: memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -80,14 +83,15 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
 //      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
 //      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
-//      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 
 func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute4"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "distribute4"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c: memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -108,14 +112,15 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-//      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 
 func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute5"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "distribute5"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c: memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -138,14 +143,15 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:       %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
 //      CHECK:       %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:       %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
-//      CHECK:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:       linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 
 func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute6"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "distribute6"}
+    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%c: memref<?x?xf32>)
   return
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -165,4 +171,4 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
 //      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
-//      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index 9a1bbfc1dc18..cd20d4f9e537 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -31,10 +31,10 @@
 func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %arg0, %arg1, %arg2 :
-    (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul
+    ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                      memref<?x?xf32, offset: ?, strides: [?, 1]>)
+   outs(%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // TILE-2-LABEL: func @matmul(
@@ -50,10 +50,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-2:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]]
 //       TILE-2:   %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-2:   linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] :
-//       TILE-2:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-2:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-2:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-2:   linalg.matmul ins(%[[sAi]]{{.*}} outs(%[[sCi]]
 
 // TILE-02-LABEL: func @matmul(
 //       TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -68,10 +65,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-02:   %[[localK:.*]] = dim %{{.*}}, %c1
 //       TILE-02:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]]
 //       TILE-02:   %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-02:   linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-//       TILE-02:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-02:   linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
 
 // TILE-002-LABEL: func @matmul(
 //       TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -86,10 +80,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-002:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]]
 //       TILE-002:   %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
 //       TILE-002:   %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-002:   linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-//       TILE-002:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-002:   linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
 
 // TILE-234-LABEL: func @matmul(
 //       TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -118,10 +109,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:        %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]]
 //       TILE-234:        %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //
-//       TILE-234:        linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-//       TILE-234:          (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-234:        linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
 
 // When the buffer shapes are known at compile time, it is possible to avoid
 // the "min" in subview size computation. This test uses buffer sizes divisible
@@ -130,10 +118,10 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
                     %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
                     %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %arg0, %arg1, %arg2 :
-    (memref<10x16xf32, offset: ?, strides: [?, 1]>,
-     memref<16x12xf32, offset: ?, strides: [?, 1]>,
-     memref<10x12xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul
+    ins(%arg0, %arg1: memref<10x16xf32, offset: ?, strides: [?, 1]>,
+                      memref<16x12xf32, offset: ?, strides: [?, 1]>)
+   outs(%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // TILE-2-LABEL: func @matmul_static(
@@ -148,7 +136,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-2:   %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
 //       TILE-2:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-//       TILE-2:   linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]]
+//       TILE-2:   linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
 
 // TILE-02-LABEL: func @matmul_static(
 //       TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -159,10 +147,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-02:   %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
 //       TILE-02:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
 //       TILE-02:   %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-//       TILE-02:   linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-//       TILE-02:     (memref<10x16xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<16x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<10x?xf32, #[[$strided2D]]>)
+//       TILE-02:   linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
 
 // TILE-002-LABEL: func @matmul_static(
 //       TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -173,10 +158,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-002:   %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
 //       TILE-002:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
 //       TILE-002:   %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-//       TILE-002:   linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-//       TILE-002:     (memref<10x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x12xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<10x12xf32, #[[$strided2D]]>)
+//       TILE-002:   linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
 
 // TILE-234-LABEL: func @matmul_static(
 //       TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -193,16 +175,13 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:        %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //       TILE-234:        %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //
-//       TILE-234:        linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-//       TILE-234:          (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-234:        linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
 
 func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec %arg0, %arg1, %arg2 : (
-    memref<?x?xf32, offset: ?, strides: [?, 1]>, 
-    memref<?xf32, offset: ?, strides: [1]>, 
-    memref<?xf32, offset: ?, strides: [1]>)
+  linalg.matvec
+    ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                      memref<?xf32, offset: ?, strides: [1]>)
+   outs(%arg2: memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // TILE-2-LABEL: func @matvec(
@@ -220,7 +199,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-2:   %[[localN:.*]] = dim %{{.*}}, %c0
 //       TILE-2:   %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]]
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-2:   linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+//       TILE-2:   linalg.matvec ins(%[[sAi]], %{{.*}} outs(%[[sCi]]
 
 // TILE-02-LABEL: func @matvec(
 // TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -237,7 +216,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-02:   %[[localN:.*]] = dim %{{.*}}, %c0
 //       TILE-02:   %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]]
 //       TILE-02:   %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-02:   linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+//       TILE-02:   linalg.matvec ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
 
 // TILE-002-LABEL: func @matvec(
 // TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -268,12 +247,12 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-234:      %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
 //       TILE-234:      %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
 //
-//       TILE-234:      linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+//       TILE-234:      linalg.matvec ins(%[[sAij]], %[[sBj]]{{.*}} outs(%[[sCi]]
 
 func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
-  linalg.dot %arg0, %arg1, %arg2  : (memref<?xf32, offset: ?, strides: [1]>, 
-                                     memref<?xf32, offset: ?, strides: [1]>, 
-                                     memref<f32>)
+  linalg.dot
+    ins(%arg0, %arg1: memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>)
+   outs(%arg2: memref<f32>)
   return
 }
 // TILE-2-LABEL: func @dot(
@@ -287,7 +266,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
 //       TILE-2:   %[[localM:.*]] = dim %{{.*}}, %c0
 //       TILE-2:   %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]]
 //       TILE-2:   %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-2:   linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
+//       TILE-2:   linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
 
 // TILE-02-LABEL: func @dot(
 //   TILE-02-NOT: scf.for
@@ -306,7 +285,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
 //       TILE-234:    %[[localM:.*]] = dim %{{.*}}, %c0
 //       TILE-234:    %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
 //       TILE-234:    %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-234:    linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
+//       TILE-234:    linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
 
 func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
   linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32

diff  --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
index 9d86e5e3f50c..a9733cf5b04e 100644
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -6,8 +6,8 @@ func @gemm(%arg0 : memref<?x?xf32>,
            %arg1 : memref<?x?xf32>,
            %arg2 : memref<?x?xf32>)
 {
-  linalg.matmul %arg0, %arg1, %arg2
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%arg2: memref<?x?xf32>)
   return
 }
 // CHECK-LABEL: func @gemm
@@ -21,7 +21,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       CHECK:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
 //       CHECK:       %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
 //       CHECK:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-//       CHECK:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//       CHECK:       linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // TILE1-LABEL: func @gemm
 //   TILE1-DAG:   %[[C2:.*]] = constant 2 : index
@@ -30,7 +30,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       TILE1:     %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //       TILE1:     %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //   TILE1-NOT:     subview
-//       TILE1:     linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]]
+//       TILE1:     linalg.matmul ins(%[[SV1]], %{{.*}} outs(%[[SV3]]
 
 // TILE2-LABEL: func @gemm
 //   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
@@ -40,7 +40,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       TILE2:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //       TILE2:       %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
 //       TILE2:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-//       TILE2:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//       TILE2:       linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 683aeb241318..bc3b7477885a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -5,10 +5,10 @@
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
-  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} :
-    (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+  linalg.matmul {__internal_linalg_transform__ = "START"}
+    ins(%A, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+   outs(%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
   return
 }
 
@@ -31,7 +31,8 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
   // VECTOR-CONTRACTION: vector.contract
   // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
-  linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref<f32>)
+  linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
+            outs(%C: memref<f32>)
   return
 }
 
@@ -39,8 +40,8 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
   // VECTOR-CONTRACTION: vector.contract
   // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
-  linalg.matvec %A, %B, %C :
-    (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+  linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
+            outs(%C: memref<1584xf32>)
   return
 }
 
@@ -48,8 +49,8 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
   // VECTOR-CONTRACTION: vector.contract
   // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
-  linalg.matmul %A, %B, %C :
-    (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+  linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+            outs(%C: memref<1584x1584xf32>)
   return
 }
 
@@ -57,7 +58,8 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
   // VECTOR-CONTRACTION: vector.contract
   // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
-  linalg.batch_matmul %A, %B, %C :
-    (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+  linalg.batch_matmul
+    ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+   outs(%C: memref<1584x1584x1584xf32>)
   return
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 1a4100403b00..6d0039676b04 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,10 +14,11 @@
 func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
           %y: memref<?xf32, offset: ?, strides: [1]>,
           %v: memref<f32>) {
-  linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } :
-             (memref<?xf32, offset: ?, strides: [1]>,
-              memref<?xf32, offset: ?, strides: [1]>,
-              memref<f32>)
+  linalg.dot { __internal_linalg_transform__ = "MEM" }
+    ins(%x, %y: memref<?xf32, offset: ?, strides: [1]>,
+                memref<?xf32, offset: ?, strides: [1]>)
+    outs(%v: memref<f32>)
+
   return
 }
 // CHECK-LABEL: func @dot
@@ -36,10 +37,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
 func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec %A, %x, %y :
-                (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                 memref<?xf32, offset: ?, strides: [1]>,
-                 memref<?xf32, offset: ?, strides: [1]>)
+  linalg.matvec
+    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?xf32, offset: ?, strides: [1]>)
+    outs(%y: memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @matvec
@@ -48,15 +49,17 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-DAG:     %[[c6:.*]] = constant 6 : index
 // CHECK:         scf.parallel {{.*}} step (%[[c5]])
 // CHECK:           scf.for {{.*}} step %[[c6]]
-// CHECK:             linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK:             linalg.matvec
+// CHECK:               ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
 
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } :
-    (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul { __internal_linalg_transform__ = "MEM" }
+    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?x?xf32, offset: ?, strides: [?, 1]>)
+    outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // CHECK-LABEL: func @matmul
@@ -85,10 +88,9 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:                           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
 // CHECK:                             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
 // CHECK:                               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
-// CHECK:                                 linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                 linalg.matmul
+// CHECK:                                   ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                  outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
 
 #matmul_trait = {
   args_in = 2,
@@ -137,8 +139,9 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
-    (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)
+  linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"}
+    ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+   outs(%C: memref<8x32xf32>)
   return
 }
 // CHECK-LABEL: func @vectorization_test_2
@@ -236,10 +239,10 @@ func @permute_generic_indexed(
 func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} :
-               (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                memref<?xf32, offset: ?, strides: [1]>,
+  linalg.matvec {__internal_linalg_transform__ = "__with_perm__"}
+    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                 memref<?xf32, offset: ?, strides: [1]>)
+   outs(%y: memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @matvec_perm
@@ -248,15 +251,17 @@ func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-DAG:     %[[c6:.*]] = constant 6 : index
 // CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
 // CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
-// CHECK:             linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK:             linalg.matvec
+// CHECK:               ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
 
 func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} :
-               (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+  linalg.matmul {__internal_linalg_transform__ = "__with_perm__"}
+    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                 memref<?x?xf32, offset: ?, strides: [?, 1]>)
+   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // CHECK-LABEL: func @matmul_perm
@@ -279,10 +284,9 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
 // CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
 // CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
-// CHECK:                                 linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                 linalg.matmul
+// CHECK:                                  ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                   outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
 
 func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                              %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -304,10 +308,10 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         %5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} :
-                      (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul {__internal_linalg_transform__ = "_promote_views_"}
+          ins(%3, %4: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>)
+         outs(%5: memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -336,8 +340,9 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:               linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:               linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:               linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
-// CHECK:               linalg.matmul %[[v0]], %[[v1]], %[[v2]] :
-// CHECK:                 (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+// CHECK:               linalg.matmul
+// CHECK-SAME:                 ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME:                outs(%[[v2]] : memref<?x?xf32>)
 
 func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                              %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -359,10 +364,10 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} :
-                      (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"}
+          ins(%3, %4: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>)
+         outs(%5: memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
     }
   }
@@ -391,10 +396,9 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
 // CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK-NOT:     linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK-NOT:     linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
-// CHECK:         linalg.matmul %[[v0]], %[[s1]], %[[s2]] :
-// CHECK:           (memref<?x?xf32>,
-// CHECK:            memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:            memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:         linalg.matmul
+// CHECK-SAME:           ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK-SAME:          outs(%[[s2]] : memref<?x?xf32, #[[$STRIDED_2D]]>)
 
 func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
   %c2000 = constant 2000 : index
@@ -421,8 +425,9 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
 func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                  %arg1: memref<?x?xf32>,
                                  %arg2: memref<?x?xf32>) {
-  linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "par__with_perm__"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"}
+    ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+   outs(%arg2: memref<?x?xf32>)
   return
 }
 // CHECK-LABEL: func @tile_permute_parallel_loop

diff  --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
index 731f3872f67d..68ddeb6ad9d0 100644
--- a/mlir/test/IR/slice.mlir
+++ b/mlir/test/IR/slice.mlir
@@ -5,8 +5,10 @@ func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
   %b = alloc(%arg2, %arg1) : memref<?x?xf32>
   %c = alloc(%arg0, %arg1) : memref<?x?xf32>
   %d = alloc(%arg0, %arg1) : memref<?x?xf32>
-  linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
-  linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%c : memref<?x?xf32>)
+  linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%d : memref<?x?xf32>)
   dealloc %c : memref<?x?xf32>
   dealloc %b : memref<?x?xf32>
   dealloc %a : memref<?x?xf32>

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index d75422e84124..6602060db8dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -308,6 +308,25 @@ parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
     return failure();
   return success();
 }
+static ParseResult
+parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
+                                 Type optOperandType,
+                                 const SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseKeyword("type_refs_capture"))
+    return failure();
+
+  Type operandType2, optOperandType2;
+  SmallVector<Type, 1> varOperandTypes2;
+  if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
+                                  varOperandTypes2))
+    return failure();
+
+  if (operandType != operandType2 || optOperandType != optOperandType2 ||
+      varOperandTypes != varOperandTypes2)
+    return failure();
+
+  return success();
+}
 static ParseResult parseCustomDirectiveOperandsAndTypes(
     OpAsmParser &parser, OpAsmParser::OperandType &operand,
     Optional<OpAsmParser::OperandType> &optOperand,
@@ -365,6 +384,14 @@ static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
     printer << ", " << optOperandType;
   printer << " -> (" << varOperandTypes << ")";
 }
+static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                             Type operandType,
+                                             Type optOperandType,
+                                             TypeRange varOperandTypes) {
+  printer << " type_refs_capture ";
+  printCustomDirectiveResults(printer, operandType, optOperandType,
+                              varOperandTypes);
+}
 static void
 printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
                                      Value optOperand, OperandRange varOperands,

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ae36ed1710c..2a3f5929f08e 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -53,8 +53,6 @@ def ComplexTensorOp : TEST_Op<"complex_f64_tensor"> {
   let results = (outs TensorOf<[ComplexF64]>);
 }
 
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
 def TupleOp : TEST_Op<"tuple_32_bit"> {
   let results = (outs TupleOf<[I32, F32]>);
 }
@@ -1518,6 +1516,22 @@ def FormatCustomDirectiveResults
   }];
 }
 
+def FormatCustomDirectiveResultsWithTypeRefs
+    : TEST_Op<"format_custom_directive_results_with_type_refs",
+              [AttrSizedResultSegments]> {
+  let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+                      Variadic<AnyType>:$varResults);
+  let assemblyFormat = [{
+    custom<CustomDirectiveResults>(
+      type($result), type($optResult), type($varResults)
+    )
+    custom<CustomDirectiveWithTypeRefs>(
+      type_ref($result), type_ref($optResult), type_ref($varResults)
+    )
+    attr-dict
+  }];
+}
+
 def FormatCustomDirectiveSuccessors
     : TEST_Op<"format_custom_directive_successors", [Terminator]> {
   let successors = (successor AnySuccessor:$successor,

diff  --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index ba2ea59cb22d..846011a0f488 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -51,7 +51,8 @@ func @dot() -> f32 {
   %B = view %bB[%c0][%c16] : memref<?xi8> to memref<?xf32>
   %C = view %bC[%c0][] : memref<?xi8> to memref<f32>
 
-  linalg.dot %A, %B, %C : (memref<?xf32>, memref<?xf32>, memref<f32>)
+  linalg.dot ins(%A, %B : memref<?xf32>, memref<?xf32>)
+            outs(%C : memref<f32>)
   %res = load %C[] : memref<f32>
 
   dealloc %bC : memref<?xi8>
@@ -83,7 +84,8 @@ func @matmul() -> f32 {
   %B = view %bB[%c0][%c16, %c2] : memref<?xi8> to memref<?x?xf32>
   %C = view %bC[%c0][%c2, %c2] : memref<?xi8> to memref<?x?xf32>
 
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%C : memref<?x?xf32>)
   %res = load %C[%c0, %c1] : memref<?x?xf32>
 
   dealloc %bC : memref<?xi8>

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index aad983eb85d2..9183f3a85b48 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -1,9 +1,9 @@
 // RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
 // RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
 
-// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
-//  ODS-NEXT:   NInputs<2>
-//  ODS-NEXT:   NOutputs<1>
+// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [
+//  ODS-NEXT:   NamedStructuredOpTrait
+//  ODS-NEXT:   AttrSizedOperandSegments
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test1Op::iterator_types() {
@@ -25,9 +25,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
   C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
 }
 
-// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
-//  ODS-NEXT:   NInputs<2>
-//  ODS-NEXT:   NOutputs<1>
+// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
+//  ODS-NEXT:   NamedStructuredOpTrait
+//  ODS-NEXT:   AttrSizedOperandSegments
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test2Op::iterator_types() {
@@ -49,9 +49,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
 }
 
-// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
-//  ODS-NEXT:   NInputs<2>
-//  ODS-NEXT:   NOutputs<1>
+// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
+//  ODS-NEXT:   NamedStructuredOpTrait
+//  ODS-NEXT:   AttrSizedOperandSegments
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test3Op::iterator_types() {

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 60189943ddab..0addff9f35fb 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -230,8 +230,66 @@ def DirectiveTypeZOperandInvalidH : TestFormat_Op<"type_operand_invalid_h", [{
 def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{
   type($result) type($result)
 }]>, Results<(outs I64:$result)>;
+
+//===----------------------------------------------------------------------===//
+// type_ref
+
+// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{
+  type_ref($operand) type(operands)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{
+  type_ref(operands) type($operand)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{
+  type_ref($operand) type($operand)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{
+  type_ref($result) type(results)
+}]>, Results<(outs I64:$result)>;
+// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{
+  type_ref(results) type($result)
+}]>, Results<(outs I64:$result)>;
+// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{
+  type_ref($result) type($result)
+}]>, Results<(outs I64:$result)>;
+
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{
+  type_ref(operands) attr-dict
+}]>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{
+  type(operands) type_ref($operand) attr-dict
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{
+  type($operand) type_ref($operand) attr-dict
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{
+  type(results) type_ref(results) attr-dict
+}]>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{
+  type($result) type_ref(results) attr-dict
+}]>, Results<(outs I64:$result)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{
+  type(results) type_ref($result) attr-dict
+}]>, Results<(outs I64:$result)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{
+  type($result) type_ref($result) attr-dict
+}]>, Results<(outs I64:$result)>;
+
 // CHECK-NOT: error:
-def DirectiveTypeZOperandValid : TestFormat_Op<"type_operand_valid", [{
+def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{
   type(operands) type(results) attr-dict
 }]>;
 

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 9f7c9c0f4809..5066fe5a24e6 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -237,6 +237,12 @@ test.format_custom_directive_results : i64, i64 -> (i64)
 // CHECK: test.format_custom_directive_results : i64 -> (i64)
 test.format_custom_directive_results : i64 -> (i64)
 
+// CHECK: test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64)
+test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
+test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
+
 func @foo() {
   // CHECK: test.format_custom_directive_successors ^bb1, ^bb2
   test.format_custom_directive_successors ^bb1, ^bb2

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 59d655684f48..99b0c03ce521 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -980,7 +980,7 @@ class TCParser {
 
   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
-                StringRef linalgOpName);
+                StringRef linalgOpName, ComprehensionParsingState &state);
 
   /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
@@ -1419,7 +1419,8 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
     return failure();
   }
   if (genODSDecl) {
-    printODS(os, cppOpName, tcName);
+    auto &state = perComprehensionStates.back();
+    printODS(os, cppOpName, tcName, state);
     os << "\n";
   }
   if (genODSImpl) {
@@ -1442,31 +1443,72 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
 
 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
-                        StringRef linalgOpName) {
-  const char *header = R"FMT(  def {0} : LinalgNamedStructured_Op<"{1}", [
-    NInputs<{2}>,
-    NOutputs<{3}>,
+                        StringRef linalgOpName,
+                        ComprehensionParsingState &state) {
+  const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
+    NamedStructuredOpTrait,
+    AttrSizedOperandSegments,
     SingleBlockImplicitTerminator<"YieldOp">]> {
-      let arguments = (ins Variadic<LinalgOperand>:$views);
+      let arguments = (ins Variadic<AnyShaped>:$inputs,
+                           Variadic<AnyMemRef>:$output_buffers,
+                           Variadic<AnyRankedTensor>:$init_tensors);
       let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
-      let regions = (region SizedRegion<1>:$region);
-      let builders = [OpBuilder<
-        "OpBuilder &b, OperationState &result, TypeRange outputTypes, "
-        # "ValueRange views",
+      let regions = (region AnyRegion:$region);
+
+      let builders = [ OpBuilder<
+        "OpBuilder &b, OperationState &result,"
+        "ValueRange inputs, ValueRange outputBuffers",
+        [{{
+          result.addOperands(inputs);
+          result.addOperands(outputBuffers);
+          result.addAttribute(
+            "operand_segment_sizes",
+            b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+                                static_cast<int32_t>(outputBuffers.size()),
+                                static_cast<int32_t>(0)}));
+          buildNamedStructuredOpRegionAndAttributes<{0}>(
+            b,
+            result,
+            TypeRange(inputs),
+            TypeRange(outputBuffers),
+            TypeRange(),
+            TypeRange());
+        }]>, OpBuilder<
+        "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes,"
+        "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors",
         [{{
-          result.addOperands(views);
-          result.addTypes(outputTypes);
+          result.addOperands(inputs);
+          result.addOperands(outputBuffers);
+          result.addOperands(initTensors);
+          result.addTypes(resultTensorTypes);
+          result.addAttribute(
+            "operand_segment_sizes",
+            b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+                                static_cast<int32_t>(outputBuffers.size()),
+                                static_cast<int32_t>(initTensors.size())}));
           buildNamedStructuredOpRegionAndAttributes<{0}>(
-            b, result, TypeRange(views), outputTypes);
+            b,
+            result,
+            TypeRange(inputs),
+            TypeRange(outputBuffers),
+            TypeRange(initTensors),
+            resultTensorTypes);
         }]>
       ];
-      let parser = [{
-        return ::parseNamedStructuredOp<{0}>(parser, result);
-      }];
+      let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
+      let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
+      let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
+      let hasFolder = 1;
+      let hasCanonicalizer = 1;
+
       let extraClassDeclaration = [{{
+        // Auto-generated.
         ArrayAttr iterator_types();
         ArrayAttr indexing_maps();
         static void regionBuilder(Block &block);
+
+        // Generic methods.
+        static unsigned getNumRegionArgs() {{ return {4}; }
         std::string getLibraryCallName() {{
           return generateLibraryCallName(getOperation());
         }
@@ -1481,7 +1523,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       nInputs++;
   }
 
-  os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
+  os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
+                      state.orderedTensorArgs.size());
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -1680,7 +1723,7 @@ int main(int argc, char **argv) {
   }
 
   // Include the proper Linalg header for end-to-end tblgen testing without
-  // resorting to non-portable shgell manipulations.
+  // resorting to non-portable shell manipulations.
   if (testEmitIncludeTdHeader)
     output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
 

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 01877855802d..7658963d6f9b 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -53,6 +53,7 @@ class Element {
     ResultsDirective,
     SuccessorsDirective,
     TypeDirective,
+    TypeRefDirective,
 
     /// This element is a literal.
     Literal,
@@ -230,7 +231,19 @@ class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
   /// The operand that is used to format the directive.
   std::unique_ptr<Element> operand;
 };
-} // end anonymous namespace
+
+/// This class represents the `type_ref` directive.
+class TypeRefDirective
+    : public DirectiveElement<Element::Kind::TypeRefDirective> {
+public:
+  TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+  Element *getOperand() const { return operand.get(); }
+
+private:
+  /// The operand that is used to format the directive.
+  std::unique_ptr<Element> operand;
+};
+} // namespace
 
 //===----------------------------------------------------------------------===//
 // LiteralElement
@@ -805,6 +818,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
            << llvm::formatv(
                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
                   name);
+  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+    ArgumentLengthKind lengthKind;
+    StringRef name = getTypeListName(dir->getOperand(), lengthKind);
+    // Refer to the previously encountered TypeDirective for name.
+    // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
+    // to properly track the types that will be parsed and pushed later on.
+    if (lengthKind != ArgumentLengthKind::Single)
+      body << "  const ::mlir::SmallVector<::mlir::Type, 1> &" << name
+           << "TypesRef(" << name << "Types);\n";
+    else
+      body << llvm::formatv(
+          "  ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
+          name);
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     ArgumentLengthKind ignored;
     body << "  ::llvm::ArrayRef<::mlir::Type> "
@@ -844,6 +870,15 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
     else
       body << llvm::formatv("{0}Successor", name);
 
+  } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+    ArgumentLengthKind lengthKind;
+    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv("{0}TypesRef", listName);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv("{0}TypeRef", listName);
+    else
+      body << formatv("{0}RawTypesRef[0]", listName);
   } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -876,6 +911,16 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
             "{0}Operand;\n",
             operand->getVar()->name);
       }
+    } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+      // Reference to an optional which may or may not have been set.
+      // Retrieve from vector if not empty.
+      ArgumentLengthKind lengthKind;
+      StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+      if (lengthKind == ArgumentLengthKind::Optional)
+        body << llvm::formatv(
+            "    ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
+            "? Type() : {0}TypesRef[0];\n",
+            listName);
     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
       ArgumentLengthKind lengthKind;
       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -907,6 +952,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
       body << llvm::formatv("    if ({0}Operand.hasValue())\n"
                             "      {0}Operands.push_back(*{0}Operand);\n",
                             var->name);
+    } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+      // In the `type_ref` case, do not parse a new Type that needs to be added.
+      // Just do nothing here.
     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
       ArgumentLengthKind lengthKind;
       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1101,6 +1149,15 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
   } else if (isa<SuccessorsDirective>(element)) {
     body << llvm::formatv(successorListParserCode, "full");
 
+  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+    ArgumentLengthKind lengthKind;
+    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv(variadicTypeParserCode, listName);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv(optionalTypeParserCode, listName);
+    else
+      body << formatv(typeParserCode, listName);
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1431,6 +1488,17 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
     } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
       body << successor->getVar()->name << "()";
 
+    } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+      auto *typeOperand = dir->getOperand();
+      auto *operand = dyn_cast<OperandVariable>(typeOperand);
+      auto *var = operand ? operand->getVar()
+                          : cast<ResultVariable>(typeOperand)->getVar();
+      if (var->isVariadic())
+        body << var->name << "().getTypes()";
+      else if (var->isOptional())
+        body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+      else
+        body << var->name << "().getType()";
     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
       auto *typeOperand = dir->getOperand();
       auto *operand = dyn_cast<OperandVariable>(typeOperand);
@@ -1604,6 +1672,9 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     body << "  p << ";
     genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
+  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+    body << "  p << ";
+    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     body << "  p.printFunctionalType(";
     genTypeOperandPrinter(dir->getInputs(), body) << ", ";
@@ -1670,6 +1741,7 @@ class Token {
     kw_results,
     kw_successors,
     kw_type,
+    kw_type_ref,
     keyword_end,
 
     // String valued tokens.
@@ -1874,6 +1946,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
           .Case("results", Token::kw_results)
           .Case("successors", Token::kw_successors)
           .Case("type", Token::kw_type)
+          .Case("type_ref", Token::kw_type_ref)
           .Default(Token::identifier);
   return Token(kind, str);
 }
@@ -1994,8 +2067,9 @@ class FormatParser {
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                          llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                   bool isTopLevel);
-  LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element);
+                                   bool isTopLevel, bool isTypeRef = false);
+  LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
+                                          bool isTypeRef = false);
 
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
@@ -2440,6 +2514,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_successors:
     return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
+  case Token::kw_type_ref:
+    return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
   case Token::kw_type:
     return parseTypeDirective(element, dirTok, isTopLevel);
 
@@ -2505,7 +2581,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
     return ::mlir::success();
   };
   for (auto &ele : elements) {
-    if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+    if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
+      if (failed(checkTypeOperand(typeEle->getOperand())))
+        return failure();
+    } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
       if (failed(checkTypeOperand(typeEle->getOperand())))
         return ::mlir::failure();
     } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
@@ -2565,7 +2644,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
       // Literals, custom directives, and type directives may be used,
       // but they can't anchor the group.
       .Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
-            OptionalElement, TypeDirective>([&](Element *) {
+            OptionalElement, TypeRefDirective, TypeDirective>([&](Element *) {
         if (isAnchor)
           return emitError(childLoc, "only variables can be used to anchor "
                                      "an optional group");
@@ -2628,6 +2707,13 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
   // After parsing all of the elements, ensure that all type directives refer
   // only to variables.
   for (auto &ele : elements) {
+    if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
+      if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
+        return emitError(curLoc,
+                         "type_ref directives within a custom directive "
+                         "may only refer to variables");
+      }
+    }
     if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
         return emitError(curLoc, "type directives within a custom directive "
@@ -2649,8 +2735,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
     return ::mlir::failure();
 
   // Verify that the element can be placed within a custom directive.
-  if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
-           SuccessorVariable>(parameters.back().get())) {
+  if (!isa<TypeRefDirective, TypeDirective, AttributeVariable, OperandVariable,
+           RegionVariable, SuccessorVariable>(parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
   }
@@ -2727,22 +2813,26 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
 
 LogicalResult
 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                 bool isTopLevel) {
+                                 bool isTopLevel, bool isTypeRef) {
   llvm::SMLoc loc = tok.getLoc();
   if (!isTopLevel)
     return emitError(loc, "'type' is only valid as a top-level directive");
 
   std::unique_ptr<Element> operand;
   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
-      failed(parseTypeDirectiveOperand(operand)) ||
+      failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
-  element = std::make_unique<TypeDirective>(std::move(operand));
+  if (isTypeRef)
+    element = std::make_unique<TypeRefDirective>(std::move(operand));
+  else
+    element = std::make_unique<TypeDirective>(std::move(operand));
   return ::mlir::success();
 }
 
 LogicalResult
-FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element) {
+FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
+                                        bool isTypeRef) {
   llvm::SMLoc loc = curToken.getLoc();
   if (failed(parseElement(element, /*isTopLevel=*/false)))
     return ::mlir::failure();
@@ -2752,23 +2842,36 @@ FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element) {
 
   if (auto *var = dyn_cast<OperandVariable>(element.get())) {
     unsigned opIdx = var->getVar() - op.operand_begin();
-    if (fmt.allOperandTypes || seenOperandTypes.test(opIdx))
+    if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
+    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
+      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+                                "' is not bound by a prior 'type' directive");
     seenOperandTypes.set(opIdx);
   } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
     unsigned resIdx = var->getVar() - op.result_begin();
-    if (fmt.allResultTypes || seenResultTypes.test(resIdx))
+    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
+    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
+      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+                                "' is not bound by a prior 'type' directive");
     seenResultTypes.set(resIdx);
   } else if (isa<OperandsDirective>(&*element)) {
-    if (fmt.allOperandTypes || seenOperandTypes.any())
+    if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
       return emitError(loc, "'operands' 'type' is already bound");
+    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
+      return emitError(
+          loc,
+          "'operands' 'type_ref' is not bound by a prior 'type' directive");
     fmt.allOperandTypes = true;
   } else if (isa<ResultsDirective>(&*element)) {
-    if (fmt.allResultTypes || seenResultTypes.any())
+    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
       return emitError(loc, "'results' 'type' is already bound");
+    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
+      return emitError(
+          loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
     fmt.allResultTypes = true;
   } else {
     return emitError(loc, "invalid argument to 'type' directive");
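
Note on the final hunk: the `type`/`type_ref` binding rule that `parseTypeDirectiveOperand` enforces for a single operand or result variable can be sketched in isolation as below. This is a simplified illustration under stated assumptions, not the mlir-tblgen implementation: `TypeBindingTracker` and `visit` are hypothetical names, a `std::vector<bool>` stands in for `seenOperandTypes`, and the `fmt.allOperandTypes` flag and the `operands`/`results` directive cases are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the per-variable state that the format parser
// keeps in `seenOperandTypes`; this type is not part of MLIR.
struct TypeBindingTracker {
  std::vector<bool> seen;

  explicit TypeBindingTracker(std::size_t numVars) : seen(numVars, false) {}

  // Processes one `type($name)` or `type_ref($name)` directive.
  // Returns an empty string on success, or the diagnostic text on failure.
  std::string visit(std::size_t idx, const std::string &name, bool isTypeRef) {
    assert(idx < seen.size() && "variable index out of range");
    // `type` binds a variable's type and may do so only once.
    if (!isTypeRef && seen[idx])
      return "'type' of '" + name + "' is already bound";
    // `type_ref` only refers back to a type that an earlier `type` bound.
    if (isTypeRef && !seen[idx])
      return "'type_ref' of '" + name +
             "' is not bound by a prior 'type' directive";
    seen[idx] = true;
    return "";
  }
};
```

Under this model, `type($operand) type_ref($operand)` is accepted, while `type_ref($operand) type($operand)` fails at the first directive, which is the behaviour the new `op-format-spec.td` cases check for.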


        

