[Mlir-commits] [mlir] f7f4dd6 - [mlir][spirv] Define `spirv.*DotAccSat` integer dot product ops

Jakub Kuderski llvmlistbot at llvm.org
Tue Dec 6 17:24:22 PST 2022


Author: Jakub Kuderski
Date: 2022-12-06T20:22:48-05:00
New Revision: f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b

URL: https://github.com/llvm/llvm-project/commit/f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b
DIFF: https://github.com/llvm/llvm-project/commit/f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b.diff

LOG: [mlir][spirv] Define `spirv.*DotAccSat` integer dot product ops

This covers `SDotAccSat`, `SUDotAccSat`, and `UDotAccSat`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139243

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/availability.mlir
    mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
    mlir/test/Dialect/SPIRV/IR/target-env.mlir
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
index 451aeb27207da..87f869547b404 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -187,4 +187,154 @@ def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot",
   }];
 }
 
+// -----
+
+def SPIRV_SDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SDotAccSat",
+                                                          [SignedOp]> {
+  let summary = [{
+    Signed integer dot product of Vector 1 and Vector 2 and signed
+    saturating addition of the result with Accumulator.
+  }];
+
+  let description = [{
+    Result Type must be an integer type whose Width must be greater than or
+    equal to that of the components of Vector 1 and Vector 2.
+
+    Vector 1 and Vector 2 must have the same type.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type
+    (enabled by the DotProductInput4x8Bit or DotProductInputAll capability).
+
+    The type of Accumulator must be the same as Result Type.
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of the input vectors are sign-extended to the bit width
+    of the result's type. The sign-extended input vectors are then
+    multiplied component-wise and all components of the vector resulting
+    from the component-wise multiplication are added together. Finally, the
+    resulting sum is added to the input accumulator. This final addition is
+    saturating.
+
+    If any of the multiplications or additions, with the exception of the
+    final accumulation, overflow or underflow, the result of the instruction
+    is undefined.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+    %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+    %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+    ```
+  }];
+}
+
+// -----
+
+def SPIRV_SUDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SUDotAccSat",
+                                                           [SignedOp,
+                                                            UnsignedOp]> {
+  let summary = [{
+    Mixed-signedness integer dot product of Vector 1 and Vector 2 and signed
+    saturating addition of the result with Accumulator. Components of Vector
+    1 are treated as signed, components of Vector 2 are treated as unsigned.
+  }];
+
+  let description = [{
+    Result Type must be an integer type whose Width must be greater than or
+    equal to that of the components of Vector 1 and Vector 2.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type with
+    the same number of components and same component Width (enabled by the
+    DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1
+    and Vector 2 are vectors, the components of Vector 2 must have a
+    Signedness of 0.
+
+    The type of Accumulator must be the same as Result Type.
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of Vector 1 are sign-extended to the bit width of the
+    result's type. All components of Vector 2 are zero-extended to the bit
+    width of the result's type. The sign- or zero-extended input vectors are
+    then multiplied component-wise and all components of the vector
+    resulting from the component-wise multiplication are added together.
+    Finally, the resulting sum is added to the input accumulator. This final
+    addition is saturating.
+
+    If any of the multiplications or additions, with the exception of the
+    final accumulation, overflow or underflow, the result of the instruction
+    is undefined.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+    %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+    %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+    ```
+  }];
+}
+
+// -----
+
+def SPIRV_UDotAccSatOp :
+    SPIRV_IntegerDotProductTernaryOp<"UDotAccSat", [UnsignedOp]> {
+  let summary = [{
+    Unsigned integer dot product of Vector 1 and Vector 2 and unsigned
+    saturating addition of the result with Accumulator.
+  }];
+
+  let description = [{
+    Result Type must be an integer type with Signedness of 0 whose Width
+    must be greater than or equal to that of the components of Vector 1 and
+    Vector 2.
+
+    Vector 1 and Vector 2 must have the same type.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type with
+    Signedness of 0 (enabled by the DotProductInput4x8Bit or
+    DotProductInputAll capability).
+
+    The type of Accumulator must be the same as Result Type.
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of the input vectors are zero-extended to the bit width
+    of the result's type. The zero-extended input vectors are then
+    multiplied component-wise and all components of the vector resulting
+    from the component-wise multiplication are added together. Finally, the
+    resulting sum is added to the input accumulator. This final addition is
+    saturating.
+
+    If any of the multiplications or additions, with the exception of the
+    final accumulation, overflow or underflow, the result of the instruction
+    is undefined.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+    %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+    %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 888a756be5201..a9e67ad15a956 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4829,6 +4829,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
         "op only supports the 'format' #spirv.packed_vector_format attribute");
 
   Type resultTy = op->getResultTypes().front();
+  bool hasAccumulator = op->getNumOperands() == 3;
+  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
+    return op->emitOpError(
+        "requires the same accumulator operand and result types");
+
   unsigned factorBitWidth = getBitWidth(factorTy);
   unsigned resultBitWidth = getBitWidth(resultTy);
   if (factorBitWidth > resultBitWidth)
@@ -4909,6 +4914,9 @@ getIntegerDotProductCapabilities(Operation *op) {
 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
 
 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 5cd7253da620b..f822ced02c4e3 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -143,3 +143,93 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
   return %r: i64
 }
+
+// CHECK-LABEL: sdot_acc_sat_scalar_i32_i32
+func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.SDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r: i32
+}
+
+// CHECK-LABEL: sdot_acc_sat_vector_4xi8_i64
+func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  return %r: i64
+}
+
+// CHECK-LABEL: sdot_acc_sat_vector_4xi16_i64
+func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r: i64
+}
+
+// CHECK-LABEL: sudot_acc_sat_scalar_i32_i32
+func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.SUDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r: i32
+}
+
+// CHECK-LABEL: sudot_acc_sat_vector_4xi8_i64
+func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  return %r: i64
+}
+
+// CHECK-LABEL: sudot_acc_sat_vector_4xi16_i64
+func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r: i64
+}
+
+// CHECK-LABEL: udot_acc_sat_scalar_i32_i32
+func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.UDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r: i32
+}
+
+// CHECK-LABEL: udot_acc_sat_vector_4xi8_i64
+func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+  return %r: i64
+}
+
+// CHECK-LABEL: udot_acc_sat_vector_4xi16_i64
+func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r: i64
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index c0c5cf39b03fd..9e9ae67e0d301 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -142,3 +142,158 @@ func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
   %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
   return %r : i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sdot_acc_sat_scalar_i32
+func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.SDotAccSat
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sdot_acc_sat_scalar_i64
+func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SDotAccSat
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @sdot_acc_sat_vector_4xi8
+func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.SDotAccSat
+  %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sdot_acc_sat_vector_4xi16
+func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SDotAccSat
+  %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @sdot_acc_sat_vector_8xi8
+func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SDotAccSat
+  %r = spirv.SDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  return %r : i64
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc: i32) -> i32 {
+  // expected-error @+1 {{op requires the same type for both vector operands}}
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i64, i32) -> i32
+  return %r : i32
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc: i16) -> i16 {
+  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i16) -> i16
+  return %r : i16
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc: i64) -> i64 {
+  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64, i64) -> i64
+  return %r : i64
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc: i32) -> i64 {
+  // expected-error @+1 {{requires the same accumulator operand and result types}}
+  %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i64
+  return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SUDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sudot_acc_sat_scalar_i32
+func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.SUDotAccSat
+  %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sudot_acc_sat_scalar_i64
+func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SUDotAccSat
+  %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @sudot_acc_sat_vector_4xi8
+func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.SUDotAccSat
+  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sudot_acc_sat_vector_4xi16
+func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SUDotAccSat
+  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @sudot_acc_sat_vector_8xi8
+func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.SUDotAccSat
+  %r = spirv.SUDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @udot_acc_sat_scalar_i32
+func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.UDotAccSat
+  %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @udot_acc_sat_scalar_i64
+func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.UDotAccSat
+  %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @udot_acc_sat_vector_4xi8
+func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+  // CHECK-NEXT: spirv.UDotAccSat
+  %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @udot_acc_sat_vector_4xi16
+func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.UDotAccSat
+  %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+  return %r : i64
+}
+
+// CHECK: @udot_acc_sat_vector_8xi8
+func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+  // CHECK-NEXT: spirv.UDotAccSat
+  %r = spirv.UDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+  return %r : i64
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index 91ffdf26242fc..15e05130fe5db 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -212,6 +212,77 @@ func.func @udot_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>) ->
 } {
   // CHECK: test.convert_to_udot_op
   %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0 : i64
+}
+
+// CHECK-LABEL: @sdot_acc_sat_scalar_i32_i32_capabilities
+func.func @sdot_acc_sat_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.SDotAccSat
+  %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %operand)
+         {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sudot_acc_sat_vector_4xi8_i32_capabilities
+func.func @sudot_acc_sat_vector_4xi8_i32_capabilities(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.SUDotAccSat
+  %0 = "test.convert_to_sudot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability1
+func.func @udot_acc_sat_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_acc_sat_op
+  %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability2
+func.func @udot_acc_sat_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProductInputAll, DotProductInput4x8Bit, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_acc_sat_op
+  %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_capabilities
+func.func @udot_acc_sat_vector_4xi16_i64_capabilities(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.UDotAccSat
+  %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability1
+func.func @udot_acc_sat_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_acc_sat_op
+  %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability2
+func.func @udot_acc_sat_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_acc_sat_op
+  %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
   return %0: i64
 }
 
@@ -304,3 +375,25 @@ func.func @udot_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>) -> i
   %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
   return %0: i64
 }
+
+// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_implied_extension
+func.func @sdot_acc_sat_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+  // Version 1.6 implies SPV_KHR_integer_dot_product.
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+    [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.SDotAccSat
+  %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_missing_extension
+func.func @sdot_acc_sat_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+  // Version 1.5 does not imply SPV_KHR_integer_dot_product.
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+    [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_sdot_acc_sat_op
+  %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+  return %0: i64
+}

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 13c35ca5b150d..2c117a1a05372 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -223,14 +223,24 @@ void ConvertToTargetEnv::runOnOperation() {
   static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
   static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
   static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
+  static constexpr char sDotAccSatTestOpName[] =
+      "test.convert_to_sdot_acc_sat_op";
+  static constexpr char suDotAccSatTestOpName[] =
+      "test.convert_to_sudot_acc_sat_op";
+  static constexpr char uDotAccSatTestOpName[] =
+      "test.convert_to_udot_acc_sat_op";
 
   RewritePatternSet patterns(context);
-  patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
-               ConvertToGroupNonUniformBallot, ConvertToModule,
-               ConvertToSubgroupBallot,
-               ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
-               ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
-               ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>>(context);
+  patterns.add<
+      ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
+      ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
+      ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
+      ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
+      ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
+      ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
+      ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
+      ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
+      context);
 
   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
     return signalPassFailure();


        


More information about the Mlir-commits mailing list