[Mlir-commits] [mlir] [mlir][spirv] Update integer dot product op syntax (PR #73468)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 26 16:28:59 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Make the syntax more concise and aligned with the `spirv.Dot` syntax in https://github.com/llvm/llvm-project/pull/73466.

Move some type verification from C++ to ODS.

Regexes to update existing code and tests (each search pattern is followed by `==>` and its replacement):
`(\s*\{format\s+=\s+#spirv.packed_vector_format([^}]+)\})`
==>
`, $2`

`(spirv.[SU]+Dot[a-zA-Z]*[^:]+:)(\s*\(([^,]+),[^\)]+\))(.+)`
==>
`$1 $3$4`

---

Patch is 34.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73468.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td (+15-6) 
- (modified) mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp (+2-14) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir (+11-11) 
- (modified) mlir/test/Dialect/SPIRV/IR/availability.mlir (+18-18) 
- (modified) mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir (+61-76) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
index 87f869547b40496..37a70bdf09e4b0c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -26,10 +26,6 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
     SPIRV_Integer:$result
   );
 
-  let assemblyFormat = [{
-    operands attr-dict `:` `(` type(operands) `)` `->` type($result)
-  }];
-
   // These ops require dynamic availability specification based on operand and
   // result types.
   bit autogenAvailability = 0;
@@ -40,23 +36,36 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
 
 class SPIRV_IntegerDotProductBinaryOp<string mnemonic,
                                       list<Trait> traits = []> :
-      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+      SPIRV_IntegerDotProductOp<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>])> {
   let arguments = (ins
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
     OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
   );
+
+  let assemblyFormat = [{
+    $vector1 `,` $vector2 ( `,` $format^ )? attr-dict `:`
+      type($vector1) `->` type($result)
+  }];
 }
 
 class SPIRV_IntegerDotProductTernaryOp<string mnemonic,
                                        list<Trait> traits = []> :
-      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+      SPIRV_IntegerDotProductOp<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>,
+                             AllTypesMatch<["accumulator", "result"]>])> {
   let arguments = (ins
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
     SPIRV_Integer:$accumulator,
     OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
   );
+
+  let assemblyFormat = [{
+    $vector1 `,` $vector2 `,` $accumulator ( `,` $format^ )? attr-dict `:`
+      type($vector1) `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 28efe4f046fcde9..00fc2acf7f07d0d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -30,13 +30,10 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
          "Not an integer dot product op?");
   assert(op->getNumResults() == 1 && "Expected a single result");
 
+  // ODS enforces that vector1 and vector2 share one type, and that the
+  // accumulator (if present) and the result share one type.
   Type factorTy = op->getOperand(0).getType();
-  if (op->getOperand(1).getType() != factorTy)
-    return op->emitOpError("requires the same type for both vector operands");
-
-  unsigned expectedNumAttrs = 0;
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    ++expectedNumAttrs;
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
             op->getAttr(kPackedVectorFormatAttrName));
@@ -59,16 +56,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
           factorTy));
   }
 
-  if (op->getAttrs().size() > expectedNumAttrs)
-    return op->emitError(
-        "op only supports the 'format' #spirv.packed_vector_format attribute");
-
   Type resultTy = op->getResultTypes().front();
-  bool hasAccumulator = op->getNumOperands() == 3;
-  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
-    return op->emitOpError(
-        "requires the same accumulator operand and result types");
-
   unsigned factorBitWidth = getBitWidth(factorTy);
   unsigned resultBitWidth = getBitWidth(resultTy);
   if (factorBitWidth > resultBitWidth)
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
index e13a51733ec1ee7..989108a7d8d0cfa 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -5,7 +5,7 @@
 
 // CHECK-LABEL: func.func @to_sdot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -17,7 +17,7 @@ func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
 
 // CHECK-LABEL: func.func @to_sdot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -29,7 +29,7 @@ func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
 
 // CHECK-LABEL: func.func @to_sdot_i64
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i64
 //  CHECK-NEXT:   return [[DOT]] : i64
 func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -41,7 +41,7 @@ func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
 
 // CHECK-LABEL: func.func @to_sdot_acc_i64
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i64
 //  CHECK-NEXT:   return [[DOT]] : i64
 func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -53,7 +53,7 @@ func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64)
 
 // CHECK-LABEL: func.func @to_udot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -65,7 +65,7 @@ func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
 
 // CHECK-LABEL: func.func @to_udot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -77,7 +77,7 @@ func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
 
 // CHECK-LABEL: func.func @to_signed_unsigned_dot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -89,7 +89,7 @@ func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
 
 // CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -101,7 +101,7 @@ func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
 
 // CHECK-LABEL: func.func @to_unsigned_signed_dot
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -113,7 +113,7 @@ func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
 
 // CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
 //  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
-//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : vector<4xi8> -> i32
 //  CHECK-NEXT:   return [[DOT]] : i32
 func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
   %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -128,7 +128,7 @@ func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
 //       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 : i8
 //       CHECK:   %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
 //       CHECK:   %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
-//       CHECK:   %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
+//       CHECK:   %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : vector<4xi8> -> i32
 //       CHECK:   return %[[SDOT]]
 func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
   %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index f822ced02c4e323..fb95a0c567b9748 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -60,7 +60,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -70,7 +70,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -80,7 +80,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -90,7 +90,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SUDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -100,7 +100,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SUDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -110,7 +110,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SUDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -120,7 +120,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.UDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -130,7 +130,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.UDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -140,7 +140,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -150,7 +150,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -160,7 +160,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -170,7 +170,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -180,7 +180,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.SUDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -190,7 +190,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -200,7 +200,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
 
@@ -210,7 +210,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
-  %r = spirv.UDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
 }
 
@@ -220,7 +220,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
-  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
 }
 
@@ -230,6 +230,6 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
-  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index 8d3c3b85b4887d9..b04e5603019b414 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -10,68 +10,51 @@
 // CHECK: @sdot_scalar_i32
 func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_scalar_i64
 func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_vector_4xi8
 func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  %r = spirv.SDot %a, %b : vector<4xi8> -> i32
   return %r : i32
 }
 
 // CHECK: @sdot_vector_4xi16
 func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  %r = spirv.SDot %a, %b : vector<4xi16> -> i64
   return %r : i64
 }
 
 // CHECK: @sdot_vector_8xi8
 func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
   // CHECK-NEXT: spirv.SDot
-  %r = spirv.SDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  %r = spirv.SDot %a, %b : vector<8xi8> -> i64
   return ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73468


More information about the Mlir-commits mailing list