[Mlir-commits] [mlir] f7f4dd6 - [mlir][spirv] Define `spirv.*DotAccSat` integer dot product ops
Jakub Kuderski
llvmlistbot at llvm.org
Tue Dec 6 17:24:22 PST 2022
Author: Jakub Kuderski
Date: 2022-12-06T20:22:48-05:00
New Revision: f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b
URL: https://github.com/llvm/llvm-project/commit/f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b
DIFF: https://github.com/llvm/llvm-project/commit/f7f4dd6743eac16b26c9bc711ba51ed9e20ed16b.diff
LOG: [mlir][spirv] Define `spirv.*DotAccSat` integer dot product ops
This covers `SDotAccSat`, `SUDotAccSat`, and `UDotAccSat`.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139243
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/availability.mlir
mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
mlir/test/Dialect/SPIRV/IR/target-env.mlir
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
index 451aeb27207da..87f869547b404 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -187,4 +187,154 @@ def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot",
}];
}
+// -----
+
+def SPIRV_SDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SDotAccSat",
+ [SignedOp]> {
+ let summary = [{
+ Signed integer dot product of Vector 1 and Vector 2 and signed
+ saturating addition of the result with Accumulator.
+ }];
+
+ let description = [{
+ Result Type must be an integer type whose Width must be greater than or
+ equal to that of the components of Vector 1 and Vector 2.
+
+ Vector 1 and Vector 2 must have the same type.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type
+ (enabled by the DotProductInput4x8Bit or DotProductInputAll capability).
+
+ The type of Accumulator must be the same as Result Type.
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of the input vectors are sign-extended to the bit width
+ of the result's type. The sign-extended input vectors are then
+ multiplied component-wise and all components of the vector resulting
+ from the component-wise multiplication are added together. Finally, the
+ resulting sum is added to the input accumulator. This final addition is
+ saturating.
+
+ If any of the multiplications or additions, with the exception of the
+ final accumulation, overflow or underflow, the result of the instruction
+ is undefined.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_SUDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SUDotAccSat",
+ [SignedOp,
+ UnsignedOp]> {
+ let summary = [{
+ Mixed-signedness integer dot product of Vector 1 and Vector 2 and signed
+ saturating addition of the result with Accumulator. Components of Vector
+ 1 are treated as signed, components of Vector 2 are treated as unsigned.
+ }];
+
+ let description = [{
+ Result Type must be an integer type whose Width must be greater than or
+ equal to that of the components of Vector 1 and Vector 2.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type with
+ the same number of components and same component Width (enabled by the
+ DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1
+ and Vector 2 are vectors, the components of Vector 2 must have a
+ Signedness of 0.
+
+ The type of Accumulator must be the same as Result Type.
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of Vector 1 are sign-extended to the bit width of the
+ result's type. All components of Vector 2 are zero-extended to the bit
+ width of the result's type. The sign- or zero-extended input vectors are
+ then multiplied component-wise and all components of the vector
+ resulting from the component-wise multiplication are added together.
+ Finally, the resulting sum is added to the input accumulator. This final
+ addition is saturating.
+
+ If any of the multiplications or additions, with the exception of the
+ final accumulation, overflow or underflow, the result of the instruction
+ is undefined.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_UDotAccSatOp :
+ SPIRV_IntegerDotProductTernaryOp<"UDotAccSat", [UnsignedOp]> {
+ let summary = [{
+ Unsigned integer dot product of Vector 1 and Vector 2 and unsigned
+ saturating addition of the result with Accumulator.
+ }];
+
+ let description = [{
+ Result Type must be an integer type with Signedness of 0 whose Width
+ must be greater than or equal to that of the components of Vector 1 and
+ Vector 2.
+
+ Vector 1 and Vector 2 must have the same type.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type with
+ Signedness of 0 (enabled by the DotProductInput4x8Bit or
+ DotProductInputAll capability).
+
+ The type of Accumulator must be the same as Result Type.
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of the input vectors are zero-extended to the bit width
+ of the result's type. The zero-extended input vectors are then
+ multiplied component-wise and all components of the vector resulting
+ from the component-wise multiplication are added together. Finally, the
+ resulting sum is added to the input accumulator. This final addition is
+ saturating.
+
+ If any of the multiplications or additions, with the exception of the
+ final accumulation, overflow or underflow, the result of the instruction
+ is undefined.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ ```
+ }];
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 888a756be5201..a9e67ad15a956 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4829,6 +4829,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
"op only supports the 'format' #spirv.packed_vector_format attribute");
Type resultTy = op->getResultTypes().front();
+ bool hasAccumulator = op->getNumOperands() == 3;
+ if (hasAccumulator && op->getOperand(2).getType() != resultTy)
+ return op->emitOpError(
+ "requires the same accumulator operand and result types");
+
unsigned factorBitWidth = getBitWidth(factorTy);
unsigned resultBitWidth = getBitWidth(resultTy);
if (factorBitWidth > resultBitWidth)
@@ -4909,6 +4914,9 @@ getIntegerDotProductCapabilities(Operation *op) {
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 5cd7253da620b..f822ced02c4e3 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -143,3 +143,93 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
%r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
return %r: i64
}
+
+// CHECK-LABEL: sdot_acc_sat_scalar_i32_i32
+func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.SDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: sdot_acc_sat_vector_4xi8_i64
+func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sdot_acc_sat_vector_4xi16_i64
+func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sudot_acc_sat_scalar_i32_i32
+func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.SUDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: sudot_acc_sat_vector_4xi8_i64
+func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sudot_acc_sat_vector_4xi16_i64
+func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: udot_acc_sat_scalar_i32_i32
+func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.UDotAccSat %a, %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: udot_acc_sat_vector_4xi8_i64
+func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: udot_acc_sat_vector_4xi16_i64
+func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r: i64
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index c0c5cf39b03fd..9e9ae67e0d301 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -142,3 +142,158 @@ func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
%r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
return %r : i32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sdot_acc_sat_scalar_i32
+func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.SDotAccSat
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sdot_acc_sat_scalar_i64
+func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SDotAccSat
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @sdot_acc_sat_vector_4xi8
+func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.SDotAccSat
+ %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sdot_acc_sat_vector_4xi16
+func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SDotAccSat
+ %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @sdot_acc_sat_vector_8xi8
+func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SDotAccSat
+ %r = spirv.SDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+ return %r : i64
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc: i32) -> i32 {
+ // expected-error @+1 {{op requires the same type for both vector operands}}
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i64, i32) -> i32
+ return %r : i32
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc: i16) -> i16 {
+ // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i16) -> i16
+ return %r : i16
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc: i64) -> i64 {
+ // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64, i64) -> i64
+ return %r : i64
+}
+
+// -----
+
+func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc: i32) -> i64 {
+ // expected-error @+1 {{requires the same accumulator operand and result types}}
+ %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i64
+ return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SUDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sudot_acc_sat_scalar_i32
+func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.SUDotAccSat
+ %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sudot_acc_sat_scalar_i64
+func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SUDotAccSat
+ %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @sudot_acc_sat_vector_4xi8
+func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.SUDotAccSat
+ %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sudot_acc_sat_vector_4xi16
+func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SUDotAccSat
+ %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @sudot_acc_sat_vector_8xi8
+func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.SUDotAccSat
+ %r = spirv.SUDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+ return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDotAccSat
+//===----------------------------------------------------------------------===//
+
+// CHECK: @udot_acc_sat_scalar_i32
+func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.UDotAccSat
+ %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @udot_acc_sat_scalar_i64
+func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.UDotAccSat
+ %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @udot_acc_sat_vector_4xi8
+func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 {
+ // CHECK-NEXT: spirv.UDotAccSat
+ %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @udot_acc_sat_vector_4xi16
+func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.UDotAccSat
+ %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64
+ return %r : i64
+}
+
+// CHECK: @udot_acc_sat_vector_8xi8
+func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 {
+ // CHECK-NEXT: spirv.UDotAccSat
+ %r = spirv.UDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64
+ return %r : i64
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index 91ffdf26242fc..15e05130fe5db 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -212,6 +212,77 @@ func.func @udot_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>) ->
} {
// CHECK: test.convert_to_udot_op
%0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0 : i64
+}
+
+// CHECK-LABEL: @sdot_acc_sat_scalar_i32_i32_capabilities
+func.func @sdot_acc_sat_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.SDotAccSat
+ %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %operand)
+ {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sudot_acc_sat_vector_4xi8_i32_capabilities
+func.func @sudot_acc_sat_vector_4xi8_i32_capabilities(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.SUDotAccSat
+ %0 = "test.convert_to_sudot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability1
+func.func @udot_acc_sat_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_acc_sat_op
+ %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability2
+func.func @udot_acc_sat_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>, %acc: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProductInputAll, DotProductInput4x8Bit, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_acc_sat_op
+ %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_capabilities
+func.func @udot_acc_sat_vector_4xi16_i64_capabilities(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.UDotAccSat
+ %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability1
+func.func @udot_acc_sat_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_acc_sat_op
+ %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability2
+func.func @udot_acc_sat_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_acc_sat_op
+ %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
return %0: i64
}
@@ -304,3 +375,25 @@ func.func @udot_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>) -> i
%0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
return %0: i64
}
+
+// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_implied_extension
+func.func @sdot_acc_sat_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+ // Version 1.6 implies SPV_KHR_integer_dot_product, so an empty extension list still allows the op.
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+ [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.SDotAccSat
+ %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_missing_extension
+func.func @sdot_acc_sat_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes {
+ // Version 1.5 does not imply SPV_KHR_integer_dot_product and no extensions are listed, so the op must stay unconverted.
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+ [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_sdot_acc_sat_op
+ %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64)
+ return %0: i64
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 13c35ca5b150d..2c117a1a05372 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -223,14 +223,24 @@ void ConvertToTargetEnv::runOnOperation() {
static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
+ static constexpr char sDotAccSatTestOpName[] =
+ "test.convert_to_sdot_acc_sat_op";
+ static constexpr char suDotAccSatTestOpName[] =
+ "test.convert_to_sudot_acc_sat_op";
+ static constexpr char uDotAccSatTestOpName[] =
+ "test.convert_to_udot_acc_sat_op";
RewritePatternSet patterns(context);
- patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
- ConvertToGroupNonUniformBallot, ConvertToModule,
- ConvertToSubgroupBallot,
- ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
- ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
- ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>>(context);
+ patterns.add<
+ ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
+ ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
+ ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
+ ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
+ ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
+ ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
+ ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
+ ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
+ context);
if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list