[Mlir-commits] [mlir] 9b72b47 - Revert "[mlir][Linalg] Retire C++ MatmulOp in favor of a linalg-ods-gen'd op."

Kirill Bobyrev llvmlistbot at llvm.org
Tue Jun 16 02:03:01 PDT 2020


Author: Kirill Bobyrev
Date: 2020-06-16T11:02:28+02:00
New Revision: 9b72b47ed63351ee5ceff4c44ccd9a71dc7dad27

URL: https://github.com/llvm/llvm-project/commit/9b72b47ed63351ee5ceff4c44ccd9a71dc7dad27
DIFF: https://github.com/llvm/llvm-project/commit/9b72b47ed63351ee5ceff4c44ccd9a71dc7dad27.diff

LOG: Revert "[mlir][Linalg] Retire C++ MatmulOp in favor of a linalg-ods-gen'd op."

This reverts commit 8c6c49f293fc85e14d811d772bdc9a68464d67b4.

As discussed offline, this patch breaks internal builds and tests so I'm
reverting it for now.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/affine.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-2-level.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/promote.mlir
    mlir/test/Dialect/Linalg/promotion_options.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile.mlir
    mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index c510010acd0b..9f9b53a22011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,8 +1,3 @@
-ods_def<MatmulOp>:
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
-}
-
 ods_def<BatchMatmulOp>:
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1292344484b0..cc7eb5ce9d68 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -225,6 +225,36 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
+def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
+
+  let arguments = (ins AnyStridedMemRefOfRank<2>,
+                       AnyStridedMemRefOfRank<2>,
+                       AnyStridedMemRefOfRank<2>);
+
+  let extraClassDeclaration = libraryCallName # [{
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      return SmallVector<StringRef, 8>{
+        getParallelIteratorTypeName(),
+        getParallelIteratorTypeName(),
+        getReductionIteratorTypeName()};
+    }
+
+    //   A(i, r_k) * B(r_k, j) -> C(i, j)
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      MLIRContext *context = getContext();
+      AffineExpr i, j, r_k;
+      bindDims(context, i, j, r_k);
+      return SmallVector<AffineMap, 8>{
+        AffineMap::get(3, 0, {i, r_k}, context),
+        AffineMap::get(3, 0, {r_k, j},context),
+        AffineMap::get(3, 0, {i, j}, context)
+      };
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
 /// A base class for pooling operation such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 ///   OptionalAttr<I64ArrayAttr>:$strides

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 7b30646136cb..ca6ca8b24732 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -241,11 +241,8 @@ void mlir::populateLinalgToStandardConversionPatterns(
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
       LinalgOpConversion<IndexedGenericOp>,
+      LinalgOpConversion<MatmulOp>,
       LinalgOpConversion<MatvecOp>>(ctx);
-  // TODO: collect all auto-generated named ops with a tblgen directive.
-  patterns.insert<
-      LinalgOpConversion<BatchMatmulOp>,
-      LinalgOpConversion<MatmulOp>>(ctx);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fe2f123a00cb..e9bd082fbcae 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1128,6 +1128,10 @@ LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1189,7 +1193,7 @@ static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
   p << op.getOperationName() << ' ';
   p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
   p << ' ' << op.getOperands();
-  p << " : (" << op.getOperandTypes() << ")";
+  p << ": (" << op.getOperandTypes() << ")";
   auto outputTensorTypes = op.getResultTypes();
   if (!outputTensorTypes.empty())
     p << " -> (" << outputTensorTypes << ")";
@@ -1201,8 +1205,8 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
 
   // Optional attributes may be added.
-  if (parser.parseOperandList(operandsInfo) ||
-      parser.parseOptionalAttrDict(result.attributes))
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operandsInfo))
     return failure();
 
   SmallVector<Type, 8> operandTypes;
@@ -1238,7 +1242,3 @@ LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
                                   SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 56078a4a6c08..c1080ead1cf9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -268,6 +268,22 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
   }
 };
 
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       MatmulOp matmulOp) {
+    assert(matmulOp.hasBufferSemantics() &&
+           "expected linalg op with buffer semantics");
+    assert(allIvs.size() == 3);
+    Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
+    IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutputBuffer(0));
+    // Emit scalar form.
+    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
+  }
+};
+
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
@@ -721,6 +737,7 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)

diff  --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index cb2064602c47..5d20b0b4e0a9 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -15,7 +15,7 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
   %B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
   %C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return
 }
 

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cb7df05d63e..914c7b7ce345 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -14,8 +14,8 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
   // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
   %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
 
-  // CHECK:  linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>)
-  linalg.matmul %3, %3, %3 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  // CHECK:  linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
+  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return %4: memref<?x?xf32>
 }
 

diff  --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 44dd268998d2..7be54f45b473 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -12,7 +12,7 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
   %0 = dim %C, %c0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
   scf.for %arg5 = %c0 to %0 step %c20 {
     scf.for %arg6 = %c0 to %2 step %c30 {
       scf.for %arg7 = %c0 to %1 step %c40 {
@@ -28,7 +28,7 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
               %14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               %16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               %17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-              linalg.matmul %14, %16, %17 : (memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
+              linalg.matmul(%14, %16, %17) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
             }
           }
         }

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 7fb5a7ab4e85..da6bd26ef1de 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -14,10 +14,10 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
   %0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, 1]>
   %1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
   %2 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, 1]>,
-     memref<?x?xf32, offset: 0, strides: [?, 1]>,
-     memref<?x?xf32, offset: 0, strides: [?, 1]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, 1]>,
+    memref<?x?xf32, offset: 0, strides: [?, 1]>,
+    memref<?x?xf32, offset: 0, strides: [?, 1]>
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
       scf.for %arg7 = %c0 to %1 step %c4 {
@@ -30,10 +30,10 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
         %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, 1]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -61,10 +61,10 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   %0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -80,10 +80,10 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -113,10 +113,10 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   %0 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -132,10 +132,10 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -165,14 +165,14 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c4 = constant 4 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
+  linalg.matmul(%A, %B, %D) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   %0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -188,10 +188,10 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -227,14 +227,14 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %0 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %C, %B, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
+  linalg.matmul(%C, %B, %D) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %1 step %c2 {
     scf.for %arg6 = %c0 to %0 step %c3 {
       scf.for %arg7 = %c0 to %2 step %c4 {
@@ -247,10 +247,10 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -275,9 +275,9 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 // CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
 // CHECK-DAG:    %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
 // CHECK-DAG:    %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK:        linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]]
-// CHECK:        linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]]
-// CHECK:        linalg.matmul %[[D_IK]], %[[B_KJ]], %[[E_IJ]]
+// CHECK:        linalg.matmul(%[[A_I0]], %[[B_00]], %[[C_I0_]])
+// CHECK:        linalg.matmul(%[[C_I0]], %[[B_0K]], %[[D_IK_]])
+// CHECK:        linalg.matmul(%[[D_IK]], %[[B_KJ]], %[[E_IJ]])
 
 // -----
 
@@ -297,14 +297,14 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c3 = constant 3 : index
   %c2 = constant 2 : index
   %0 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %C, %E :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
+  linalg.matmul(%A, %C, %E) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %1 step %c2 {
@@ -322,10 +322,10 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -359,14 +359,14 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %3 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %4 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %C, %E :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %C, %E) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
       scf.for %arg7 = %c0 to %1 step %c4 {
@@ -379,10 +379,10 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %7, %9, %10 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%7, %9, %10) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -398,10 +398,10 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %7, %9, %10 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%7, %9, %10) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -414,7 +414,7 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 // CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  linalg.matmul %[[A]], %[[C]], %[[E]]
+// CHECK:  linalg.matmul(%[[A]], %[[C]], %[[E]])
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
@@ -445,14 +445,14 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   %c2 = constant 2 : index
   %0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
-  linalg.matmul %A, %C, %D :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %C, %D) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
   %2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
   scf.for %arg5 = %c0 to %0 step %c2 {
     scf.for %arg6 = %c0 to %2 step %c3 {
@@ -469,10 +469,10 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %5, %7, %8 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%5, %7, %8) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -742,10 +742,10 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
   %B = alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
   %C = alloc(%dim, %dim)[%s0, %s1]  : memref<?x?xf32, offset: 0, strides: [?, ?]>
 
-  linalg.matmul %A, %B, %C :
-    (memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>,
-     memref<?x?xf32, offset: 0, strides: [?, ?]>)
+  linalg.matmul(%A, %B, %C) :
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>,
+    memref<?x?xf32, offset: 0, strides: [?, ?]>
 
   scf.for %i = %c0 to %dim step %c2 {
     scf.for %j = %c0 to %dim step %c3 {
@@ -759,10 +759,10 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
         %2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
           memref<?x?xf32, offset: 0, strides: [?, ?]> to
           memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %0, %1, %2 :
-          (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>,
-           memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%0, %1, %2) :
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>,
+          memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 020e43da00b4..d1e86ba361c6 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -33,7 +33,7 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
   %B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
   %C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return
 }
 // CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref<?xi8>,

diff  --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index fe36755345ef..c4c5e00c42e7 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -26,10 +26,7 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
         %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
         %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
         %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-        linalg.matmul %11, %14, %17 :
-          (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           memref<?x?xf32, offset: ?, strides: [?, 1]>)
+        linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
       }
     }
   }
@@ -63,14 +60,9 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] :
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.matmul(%[[partialA]], %[[partialB]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D_dynamic]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) :
-//       CHECK:           memref<?x?xf32, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf32, #[[$strided2D]]>
+//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[$strided2D_dynamic]]>, memref<?x?xf32, #[[$strided2D]]>
 //
 //       CHECK:         dealloc %[[tmpA]] : memref<32xi8>
 //       CHECK:         dealloc %[[tmpB]] : memref<48xi8>
@@ -96,10 +88,7 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
         %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
         %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
         %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
-        linalg.matmul %11, %14, %17 :
-          (memref<?x?xf64, offset: ?, strides: [?, 1]>,
-           memref<?x?xf64, offset: ?, strides: [?, 1]>,
-           memref<?x?xf64, offset: ?, strides: [?, 1]>)
+        linalg.matmul(%11, %14, %17) : memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>
       }
     }
   }
@@ -133,15 +122,72 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] :
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.matmul(%[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D_dynamic]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
 //
-//       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
-//       CHECK:           memref<?x?xf64, #[[$strided2D_dynamic]]>,
-//       CHECK:           memref<?x?xf64, #[[$strided2D]]>
+//       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[$strided2D_dynamic]]>, memref<?x?xf64, #[[$strided2D]]>
 //
 //       CHECK:         dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         dealloc %[[tmpB_f64]] : memref<96xi8>
 //       CHECK:         dealloc %[[tmpC_f64]] : memref<48xi8>
+
+// -----
+
+func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xi32>
+  %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xi32>
+  %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xi32>
+  %6 = dim %3, %c0 : memref<?x?xi32>
+  %7 = dim %3, %c1 : memref<?x?xi32>
+  %8 = dim %4, %c1 : memref<?x?xi32>
+  scf.for %arg4 = %c0 to %6 step %c2 {
+    scf.for %arg5 = %c0 to %8 step %c3 {
+      scf.for %arg6 = %c0 to %7 step %c4 {
+        %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+        %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+        %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+        linalg.matmul(%11, %14, %17) : memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @matmul_i32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:         %[[vA_i32:.*]] = subview {{.*}} : memref<?x?xi32>
+//       CHECK:         %[[vB_i32:.*]] = subview {{.*}} : memref<?x?xi32>
+//       CHECK:         %[[vC_i32:.*]] = subview {{.*}} : memref<?x?xi32>
+///
+//       CHECK:         %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
+//       CHECK:         %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][{{.*}}][{{.*}}] : memref<32xi8> to memref<?x?xi32>
+//     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+//       CHECK:         %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[$strided2D_dynamic]]>
+///
+//       CHECK:         %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
+//       CHECK:         %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xi32>
+//     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+//       CHECK:         %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[$strided2D_dynamic]]>
+///
+//       CHECK:         %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
+//       CHECK:         %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xi32>
+//     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+//       CHECK:         %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[$strided2D_dynamic]]>
+
+//       CHECK:         linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[$strided2D]]>, memref<?x?xi32, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[$strided2D]]>, memref<?x?xi32, #[[$strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[$strided2D]]>, memref<?x?xi32, #[[$strided2D_dynamic]]>
+//
+//       CHECK:         linalg.matmul(%[[partialA_i32]], %[[partialB_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[$strided2D_dynamic]]>, memref<?x?xi32, #[[$strided2D_dynamic]]>, memref<?x?xi32, #[[$strided2D_dynamic]]>
+//
+//       CHECK:         linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[$strided2D_dynamic]]>, memref<?x?xi32, #[[$strided2D]]>
+//
+//       CHECK:         dealloc %[[tmpA_i32]] : memref<32xi8>
+//       CHECK:         dealloc %[[tmpB_i32]] : memref<48xi8>
+//       CHECK:         dealloc %[[tmpC_i32]] : memref<24xi8>

diff  --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 270a63cf8609..e6c8e2158fc3 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -2,8 +2,8 @@
 
 func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 {
-   linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"}
-     : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+   linalg.matmul(%a, %b, %c) {__internal_linalg_transform__ = "START"}
+     : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
    return
 }
 
@@ -26,7 +26,7 @@ func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:       linalg.copy(%[[T7]], %[[T19]])
 //      CHECK:       linalg.fill(%[[T21]], %[[C42]])
 //      CHECK:       linalg.copy(%[[T17]], %[[T21]])
-//      CHECK:       linalg.matmul %[[T19]], %[[T12]], %[[T21]]
+//      CHECK:       linalg.matmul(%[[T19]], %[[T12]], %[[T21]])
 //  CHECK-NOT:       linalg.fill
 //      CHECK:       linalg.copy(%[[T21]], %[[T17]])
 //      CHECK:       dealloc %[[T18]]

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index aaa2890060e6..6fded85c504a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -83,9 +83,9 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %arg1: memref<?xf32, offset: ?, strides: [1]>,
           %arg2: memref<?xf32, offset: ?, strides: [1]>,
           %arg3: memref<f32>) {
-  linalg.matmul %arg0, %arg0, %arg0 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
+  linalg.matmul(%arg0, %arg0, %arg0) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                                       memref<?x?xf32, offset: ?, strides: [?, 1]>)
+                                       memref<?x?xf32, offset: ?, strides: [?, 1]>
   linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?xf32, offset: ?, strides: [1]>,
                                        memref<?xf32, offset: ?, strides: [1]>
@@ -95,10 +95,10 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
   return
 }
 // CHECK-LABEL: func @ops(%
-//  CHECK-NEXT:  linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} :
-//  CHECK-SAME:    (memref<?x?xf32, #[[$strided2D]]>,
+//  CHECK-NEXT:  linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) :
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>)
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>
 //  CHECK-NEXT:  linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) :
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
 //  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>,

diff  --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index a36adf242d63..f55e20fe76c9 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -20,21 +20,12 @@
 // TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
 // TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
 
-//   TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)>
-//  TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)>
-// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)>
-
 //   TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
 //  TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
 // TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
 
-func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-             %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-             %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %arg0, %arg1, %arg2 :
-    (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>)
+func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.matmul(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // TILE-2-LABEL: func @matmul(
@@ -50,10 +41,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-2:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]]
 //       TILE-2:   %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-2:   linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] :
-//       TILE-2:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-2:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-2:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-2:   linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
 
 // TILE-02-LABEL: func @matmul(
 //       TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -68,10 +56,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-02:   %[[localK:.*]] = dim %{{.*}}, %c1
 //       TILE-02:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]]
 //       TILE-02:   %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-02:   linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-//       TILE-02:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-02:   linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
 
 // TILE-002-LABEL: func @matmul(
 //       TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -86,10 +71,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-002:   %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]]
 //       TILE-002:   %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
 //       TILE-002:   %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-002:   linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-//       TILE-002:     (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-002:   linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
 
 // TILE-234-LABEL: func @matmul(
 //       TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -118,22 +100,14 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:        %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]]
 //       TILE-234:        %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //
-//       TILE-234:        linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-//       TILE-234:          (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-234:        linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
 
 // When the buffer shapes are known at compile time, it is possible to avoid
 // the "min" in subview size computation. This test uses buffer sizes divisible
 // by respective tile sizes (M=10 divisble by 2, N=12 divisible by 2 and 3,
 // K=16 divisble by 2 and 4).
-func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
-                    %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
-                    %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %arg0, %arg1, %arg2 :
-    (memref<10x16xf32, offset: ?, strides: [?, 1]>,
-     memref<16x12xf32, offset: ?, strides: [?, 1]>,
-     memref<10x12xf32, offset: ?, strides: [?, 1]>)
+func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>, %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>, %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
+  linalg.matmul(%arg0, %arg1, %arg2) : memref<10x16xf32, offset: ?, strides: [?, 1]>, memref<16x12xf32, offset: ?, strides: [?, 1]>, memref<10x12xf32, offset: ?, strides: [?, 1]>
   return
 }
 // TILE-2-LABEL: func @matmul_static(
@@ -144,39 +118,33 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-2-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-2-DAG: %[[M:.*]] = constant 10 : index
 //       TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
-//       TILE-2:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]])
+//       TILE-2:   %[[MIN2:.*]] = affine.min #map2(%[[I]])
 //       TILE-2:   %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
-//       TILE-2:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
+//       TILE-2:   %[[MIN22:.*]] = affine.min #map2(%[[I]])
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-//       TILE-2:   linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]]
+//       TILE-2:   linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]])
 
 // TILE-02-LABEL: func @matmul_static(
 //       TILE-02-DAG: %[[C0:.*]] = constant 0 : index
 //       TILE-02-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-02-DAG: %[[N:.*]] = constant 12 : index
 //       TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
-//       TILE-02:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]])
+//       TILE-02:   %[[MIN2:.*]] = affine.min #map2(%[[J]])
 //       TILE-02:   %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
-//       TILE-02:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
+//       TILE-02:   %[[MIN22:.*]] = affine.min #map2(%[[J]])
 //       TILE-02:   %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-//       TILE-02:   linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-//       TILE-02:     (memref<10x16xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<16x?xf32, #[[$strided2D]]>,
-//       TILE-02:      memref<10x?xf32, #[[$strided2D]]>)
+//       TILE-02:   linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<10x16xf32, #[[$strided2D]]>, memref<16x?xf32, #[[$strided2D]]>, memref<10x?xf32, #[[$strided2D]]>
 
 // TILE-002-LABEL: func @matmul_static(
 //       TILE-002-DAG: %[[C0:.*]] = constant 0 : index
 //       TILE-002-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-002-DAG: %[[C16:.*]] = constant 16 : index
 //       TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-//       TILE-002:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]])
+//       TILE-002:   %[[MIN2:.*]] = affine.min #map2(%[[K]])
 //       TILE-002:   %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-//       TILE-002:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
+//       TILE-002:   %[[MIN22:.*]] = affine.min #map2(%[[K]])
 //       TILE-002:   %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-//       TILE-002:   linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-//       TILE-002:     (memref<10x?xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<?x12xf32, #[[$strided2D]]>,
-//       TILE-002:      memref<10x12xf32, #[[$strided2D]]>)
+//       TILE-002:   linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref<10x?xf32, #[[$strided2D]]>, memref<?x12xf32, #[[$strided2D]]>, memref<10x12xf32, #[[$strided2D]]>
 
 // TILE-234-LABEL: func @matmul_static(
 //       TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -193,10 +161,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:        %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //       TILE-234:        %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //
-//       TILE-234:        linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-//       TILE-234:          (memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>,
-//       TILE-234:           memref<?x?xf32, #[[$strided2D]]>)
+//       TILE-234:        linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
 
 func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>) {
   linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>

diff  --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
index 9d86e5e3f50c..bfa14570aef1 100644
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -6,8 +6,8 @@ func @gemm(%arg0 : memref<?x?xf32>,
            %arg1 : memref<?x?xf32>,
            %arg2 : memref<?x?xf32>)
 {
-  linalg.matmul %arg0, %arg1, %arg2
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul(%arg0, %arg1, %arg2)
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return
 }
 // CHECK-LABEL: func @gemm
@@ -21,7 +21,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       CHECK:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
 //       CHECK:       %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
 //       CHECK:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-//       CHECK:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//       CHECK:       linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
 
 // TILE1-LABEL: func @gemm
 //   TILE1-DAG:   %[[C2:.*]] = constant 2 : index
@@ -30,7 +30,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       TILE1:     %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //       TILE1:     %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //   TILE1-NOT:     subview
-//       TILE1:     linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]]
+//       TILE1:     linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]])
 
 // TILE2-LABEL: func @gemm
 //   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
@@ -40,7 +40,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
 //       TILE2:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 //       TILE2:       %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
 //       TILE2:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-//       TILE2:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+//       TILE2:       linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index cf75ee5691d0..73c72ba1c6ef 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -4,10 +4,10 @@
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
-  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} :
-    (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
   return
 }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index f2ae0ba76ed0..7ea28a274e05 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -53,10 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } :
-    (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>,
-     memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
+                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL: func @matmul
@@ -85,10 +85,7 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:                           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
 // CHECK:                             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
 // CHECK:                               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
-// CHECK:                                 linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                 linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>
 
 #matmul_trait = {
   args_in = 2,
@@ -120,8 +117,8 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
-    (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)
+  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "VECTORIZE"} :
+    memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
   return
 }
 // CHECK-LABEL: func @vectorization_test_2
@@ -219,10 +216,10 @@ func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} :
-               (memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                memref<?x?xf32, offset: ?, strides: [?, 1]>)
+  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
+               memref<?x?xf32, offset: ?, strides: [?, 1]>,
+               memref<?x?xf32, offset: ?, strides: [?, 1]>,
+               memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL: func @matmul_perm
@@ -245,10 +242,7 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
 // CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
 // CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
-// CHECK:                                 linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:                                   memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:                                 linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>
 
 func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                              %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -270,10 +264,10 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         %5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} :
-                      (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} :
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -302,8 +296,7 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:               linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:               linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:               linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
-// CHECK:               linalg.matmul %[[v0]], %[[v1]], %[[v2]] :
-// CHECK:                 (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+// CHECK:               linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
 
 func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                              %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -325,10 +318,10 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
              memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} :
-                      (memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>,
-                       memref<?x?xf32, offset: ?, strides: [?, ?]>)
+        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
   }
@@ -357,10 +350,7 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
 // CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK-NOT:     linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK-NOT:     linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
-// CHECK:         linalg.matmul %[[v0]], %[[s1]], %[[s2]] :
-// CHECK:           (memref<?x?xf32>,
-// CHECK:            memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK:            memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK:         linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>
 
 func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
   %c2000 = constant 2000 : index
@@ -387,8 +377,8 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
 func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                  %arg1: memref<?x?xf32>,
                                  %arg2: memref<?x?xf32>) {
-  linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "par__with_perm__"}
-    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul(%arg0, %arg1, %arg2) {__internal_linalg_transform__ = "par__with_perm__"}
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return
 }
 // CHECK-LABEL: func @tile_permute_parallel_loop

diff  --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index dd6feb96240e..43641fd1ab7f 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -83,7 +83,7 @@ func @matmul() -> f32 {
   %B = view %bB[%c0][%c16, %c2] : memref<?xi8> to memref<?x?xf32>
   %C = view %bC[%c0][%c2, %c2] : memref<?xi8> to memref<?x?xf32>
 
-  linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   %res = load %C[%c0, %c1] : memref<?x?xf32>
 
   dealloc %bC : memref<?xi8>

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 12e6aeef9162..c417995e67ec 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1474,10 +1474,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           TypeRange inputTypes, TypeRange outputTypes);
 
         static void regionBuilder(Block &block);
-
-        std::string getLibraryCallName() {{
-          return generateLibraryCallName(getOperation());
-        }
       }];
   })FMT";
 


        


More information about the Mlir-commits mailing list