[Mlir-commits] [mlir] [mlir][spirv] Update integer dot product op syntax (PR #73468)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 26 16:28:29 PST 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/73468

Make the syntax more concise and aligned with the `spirv.Dot` syntax in https://github.com/llvm/llvm-project/pull/73466.

Move some type verification from C++ to ODS.

Regexes to update existing code and tests:
`(\s*\{format\s+=\s+#spirv.packed_vector_format([^}]+)\})`
==>
`, $2`

`(spirv.[SU]+Dot[a-zA-Z]*[^:]+:)(\s*\(([^,]+),[^\)]+\))(.+)`
==>
`$1 $3$4`

>From 200a9dd28e77b0a6f9e4ac0f10cdd4ec80412278 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 18:46:07 -0500
Subject: [PATCH] [mlir][spirv] Update integer dot product op syntax

Make the syntax more concise and aligned with the `spirv.Dot` syntax in
https://github.com/llvm/llvm-project/pull/73466.

Move some type verification from C++ to ODS.

Regexes to update existing code and tests:
`(\s*\{format\s+=\s+#spirv.packed_vector_format([^}]+)\})`
==>
`, $2`

`(spirv.[SU]+Dot[a-zA-Z]*[^:]+:)(\s*\(([^,]+),[^\)]+\))(.+)`
==>
`$1 $3$4`
---
 .../SPIRV/IR/SPIRVIntegerDotProductOps.td     |  21 ++-
 .../Dialect/SPIRV/IR/IntegerDotProductOps.cpp |  16 +-
 .../vector-reduction-to-spirv-dot-prod.mlir   |  22 +--
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  36 ++---
 .../SPIRV/IR/integer-dot-product-ops.mlir     | 137 ++++++++----------
 5 files changed, 107 insertions(+), 125 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
index 87f869547b40496..37a70bdf09e4b0c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -26,10 +26,6 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
     SPIRV_Integer:$result
   );
 
-  let assemblyFormat = [{
-    operands attr-dict `:` `(` type(operands) `)` `->` type($result)
-  }];
-
   // These ops require dynamic availability specification based on operand and
   // result types.
   bit autogenAvailability = 0;
@@ -40,23 +36,36 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
 
 class SPIRV_IntegerDotProductBinaryOp<string mnemonic,
                                       list<Trait> traits = []> :
-      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+      SPIRV_IntegerDotProductOp<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>])> {
   let arguments = (ins
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
     OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
   );
+
+  let assemblyFormat = [{
+    $vector1 `,` $vector2 ( `,` $format^ )? attr-dict `:`
+      type($vector1) `->` type($result)
+  }];
 }
 
 class SPIRV_IntegerDotProductTernaryOp<string mnemonic,
                                        list<Trait> traits = []> :
-      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+      SPIRV_IntegerDotProductOp<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>,
+                             AllTypesMatch<["accumulator", "result"]>])> {
   let arguments = (ins
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
     SPIRV_Integer:$accumulator,
     OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
   );
+
+  let assemblyFormat = [{
+    $vector1 `,` $vector2 `,` $accumulator ( `,` $format^ )? attr-dict `:`
+      type($vector1) `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 28efe4f046fcde9..00fc2acf7f07d0d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -30,13 +30,10 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
          "Not an integer dot product op?");
   assert(op->getNumResults() == 1 && "Expected a single result");
 
+  // ODS enforces that `vector1` and `vector2` have the same type, and that the
+  // accumulator (when present) has the same type as the result.
   Type factorTy = op->getOperand(0).getType();
-  if (op->getOperand(1).getType() != factorTy)
-    return op->emitOpError("requires the same type for both vector operands");
-
-  unsigned expectedNumAttrs = 0;
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    ++expectedNumAttrs;
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
             op->getAttr(kPackedVectorFormatAttrName));
@@ -59,16 +56,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
           factorTy));
   }
 
-  if (op->getAttrs().size() > expectedNumAttrs)
-    return op->emitError(
-        "op only supports the 'format' #spirv.packed_vector_format attribute");
-
   Type resultTy = op->getResultTypes().front();
-  bool hasAccumulator = op->getNumOperands() == 3;
-  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
-    return op->emitOpError(
-        "requires the same accumulator operand and result types");
-
   unsigned factorBitWidth = getBitWidth(factorTy);
   unsigned resultBitWidth = getBitWidth(resultTy);
   if (factorBitWidth > resultBitWidth)
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
index e13a51733ec1ee7..989108a7d8d0cfa 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -5,7 +5,7 @@
 
 // CHECK-LABEL: func.func @to_sdot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -17,7 +17,7 @@ func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
 
 // CHECK-LABEL: func.func @to_sdot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -29,7 +29,7 @@ func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
 
 // CHECK-LABEL: func.func @to_sdot_i64
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i64
 //  CHECK-NEXT:   return [[DOT]] : i64
 func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -41,7 +41,7 @@ func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
 
 // CHECK-LABEL: func.func @to_sdot_acc_i64
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i64
 //  CHECK-NEXT:   return [[DOT]] : i64
 func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -53,7 +53,7 @@ func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64)
 
 // CHECK-LABEL: func.func @to_udot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -65,7 +65,7 @@ func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
 
 // CHECK-LABEL: func.func @to_udot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -77,7 +77,7 @@ func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
 
 // CHECK-LABEL: func.func @to_signed_unsigned_dot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -89,7 +89,7 @@ func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
 
 // CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -101,7 +101,7 @@ func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
 
 // CHECK-LABEL: func.func @to_unsigned_signed_dot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -113,7 +113,7 @@ func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
 
 // CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -128,7 +128,7 @@ func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
 //       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 : i8
 //       CHECK:   %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
 //       CHECK:   %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
-//       CHECK:   %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
+//       CHECK:   %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : vector<4xi8> -> i32
 //       CHECK:   return %[[SDOT]]
 func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index f822ced02c4e323..fb95a0c567b9748 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -60,7 +60,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -70,7 +70,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -80,7 +80,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -90,7 +90,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SUDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -100,7 +100,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SUDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -110,7 +110,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SUDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -120,7 +120,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.UDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -130,7 +130,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.UDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -140,7 +140,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -150,7 +150,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -160,7 +160,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -170,7 +170,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -180,7 +180,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SUDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -190,7 +190,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -200,7 +200,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -210,7 +210,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.UDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -220,7 +220,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -230,6 +230,6 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index 8d3c3b85b4887d9..b04e5603019b414 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -10,68 +10,51 @@
 // CHECK: @sdot_scalar_i32
 func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_scalar_i64
 func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_vector_4xi8
 func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  %r = spirv.SDot %a, %b : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_vector_4xi16
 func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SDot %a, %b : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_vector_8xi8
 func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  %r = spirv.SDot %a, %b : vector<8xi8> -> i64
   return %r : i64
 }
 
 // -----
 
+// expected-note @+1 {{prior use here}}
 func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
-  // expected-error @+1 {{op requires the same type for both vector operands}}
-  %r = spirv.SDot %a, %b : (i32, i64) -> i32
+  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
+  %r = spirv.SDot %a, %b : i32 -> i32
   return %r : i32
 }
 // -----
 
 func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}:
-        (vector<4xi8>, vector<4xi8>) -> i32
-  return %r : i32
-}
-
-// -----
-
-func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 {
-  // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
-  %r = spirv.SDot %a, %b {volatile = #spirv.decoration<Volatile>,
-                          format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
-  return %r : i32
-}
-
-// -----
-
-func.func @udot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
-  // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
-  %r = spirv.UDot %a, %b {volatile = #spirv.decoration<Volatile>}: (vector<4xi8>, vector<4xi8>) -> i32
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : vector<4xi8> -> i32
   return %r : i32
 }
 
@@ -79,7 +62,7 @@ func.func @udot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32
 
 func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i16
   return %r : i16
 }
 
@@ -87,7 +70,7 @@ func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
 
 func.func @sdot_scalar_bad_types(%a: i64, %b: i64) -> i64 {
   // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64) -> i64
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i64 -> i64
   return %r : i64
 }
 
@@ -100,35 +83,35 @@ func.func @sdot_scalar_bad_types(%a: i64, %b: i64) -> i64 {
 // CHECK: @sudot_scalar_i32
 func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
   // CHECK-NEXT: spirv.SUDot
-  %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @sudot_scalar_i64
 func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
   // CHECK-NEXT: spirv.SUDot
-  %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @sudot_vector_4xi8
 func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   // CHECK-NEXT: spirv.SUDot
-  %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  %r = spirv.SUDot %a, %b : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @sudot_vector_4xi16
 func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
   // CHECK-NEXT: spirv.SUDot
-  %r = spirv.SUDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SUDot %a, %b : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @sudot_vector_8xi8
 func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
   // CHECK-NEXT: spirv.SUDot
-  %r = spirv.SUDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  %r = spirv.SUDot %a, %b : vector<8xi8> -> i64
   return %r : i64
 }
 
@@ -141,21 +124,21 @@ func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
 // CHECK: @udot_scalar_i32
 func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
   // CHECK-NEXT: spirv.UDot
-  %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @udot_scalar_i64
 func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
   // CHECK-NEXT: spirv.UDot
-  %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @udot_vector_4xi8
 func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   // CHECK-NEXT: spirv.UDot
-  %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  %r = spirv.UDot %a, %b : vector<4xi8> -> i32
   return %r : i32
 }
 
@@ -166,69 +149,71 @@ func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
 //===----------------------------------------------------------------------===//
 
 // CHECK: @sdot_acc_sat_scalar_i32
-func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.SDotAccSat
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_acc_sat_scalar_i64
-func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SDotAccSat
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_acc_sat_vector_4xi8
-func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.SDotAccSat
-  %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_acc_sat_vector_4xi16
-func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SDotAccSat
-  %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_acc_sat_vector_8xi8
-func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SDotAccSat
-  %r = spirv.SDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %b, %acc : vector<8xi8> -> i64
   return %r : i64
 }
 
 // -----
 
-func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc: i32) -> i32 {
-  // expected-error @+1 {{op requires the same type for both vector operands}}
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i64, i32) -> i32
+// expected-note @+1 {{prior use here}}
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc : i32) -> i32 {
+  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // -----
 
-func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc: i16) -> i16 {
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc : i16) -> i16 {
   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i16) -> i16
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i16
   return %r : i16
 }
 
 // -----
 
-func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc: i64) -> i64 {
+func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc : i64) -> i64 {
   // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64, i64) -> i64
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i64 -> i64
   return %r : i64
 }
 
 // -----
 
-func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc: i32) -> i64 {
-  // expected-error @+1 {{requires the same accumulator operand and result types}}
-  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i64
+// expected-note @+1 {{prior use here}}
+func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc : i32) -> i64 {
+  // expected-error @+1 {{use of value '%acc' expects different type than prior uses: 'i64' vs 'i32'}}
+  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
@@ -239,37 +224,37 @@ func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc: i32) -> i
 //===----------------------------------------------------------------------===//
 
 // CHECK: @sudot_acc_sat_scalar_i32
-func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.SUDotAccSat
-  %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @sudot_acc_sat_scalar_i64
-func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SUDotAccSat
-  %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @sudot_acc_sat_vector_4xi8
-func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.SUDotAccSat
-  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @sudot_acc_sat_vector_4xi16
-func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SUDotAccSat
-  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @sudot_acc_sat_vector_8xi8
-func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.SUDotAccSat
-  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %b, %acc : vector<8xi8> -> i64
   return %r : i64
 }
 
@@ -280,36 +265,36 @@ func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i
 //===----------------------------------------------------------------------===//
 
 // CHECK: @udot_acc_sat_scalar_i32
-func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.UDotAccSat
-  %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @udot_acc_sat_scalar_i64
-func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.UDotAccSat
-  %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @udot_acc_sat_vector_4xi8
-func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
   // CHECK-NEXT: spirv.UDotAccSat
-  %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @udot_acc_sat_vector_4xi16
-func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.UDotAccSat
-  %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @udot_acc_sat_vector_8xi8
-func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
   // CHECK-NEXT: spirv.UDotAccSat
-  %r = spirv.UDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %b, %acc : vector<8xi8> -> i64
   return %r : i64
 }



More information about the Mlir-commits mailing list