[Mlir-commits] [mlir] 0d691ac - [mlir][spirv] Fix integer dot product format attr validation
Jakub Kuderski
llvmlistbot at llvm.org
Tue Dec 6 20:33:54 PST 2022
Author: Jakub Kuderski
Date: 2022-12-06T23:29:42-05:00
New Revision: 0d691ac4472b5a2e9f9ed3d4a7b91460791efb90
URL: https://github.com/llvm/llvm-project/commit/0d691ac4472b5a2e9f9ed3d4a7b91460791efb90
DIFF: https://github.com/llvm/llvm-project/commit/0d691ac4472b5a2e9f9ed3d4a7b91460791efb90.diff
LOG: [mlir][spirv] Fix integer dot product format attr validation
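Reject the packed vector `format` attribute when the vector operands are vector-typed (non-scalar); the attribute is only valid for scalar integer operands.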
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139495
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index a9e67ad15a956..52ad8ad5fe7c7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4806,7 +4806,9 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
if (op->getOperand(1).getType() != factorTy)
return op->emitOpError("requires the same type for both vector operands");
+ unsigned expectedNumAttrs = 0;
if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+ ++expectedNumAttrs;
auto packedVectorFormat =
op->getAttr(kPackedVectorFormatAttrName)
.dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
@@ -4816,15 +4818,20 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
assert(packedVectorFormat.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
- "unknown Packed Vector format");
+ "Unknown Packed Vector Format");
if (intTy.getWidth() != 32)
return op->emitOpError(
llvm::formatv("with specified Packed Vector Format ({0}) requires "
"integer vector operands to be 32-bits wide",
packedVectorFormat.getValue()));
+ } else {
+ if (op->hasAttr(kPackedVectorFormatAttrName))
+ return op->emitOpError(llvm::formatv(
+ "with invalid format attribute for vector operands of type '{0}'",
+ factorTy));
}
- if (op->getAttrs().size() > 1)
+ if (op->getAttrs().size() > expectedNumAttrs)
return op->emitError(
"op only supports the 'format' #spirv.packed_vector_format attribute");
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index 9e9ae67e0d301..8d3c3b85b4887 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -49,6 +49,14 @@ func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
%r = spirv.SDot %a, %b : (i32, i64) -> i32
return %r : i32
}
+// -----
+
+func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+ // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}:
+ (vector<4xi8>, vector<4xi8>) -> i32
+ return %r : i32
+}
// -----
@@ -61,6 +69,14 @@ func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 {
// -----
+func.func @udot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+ // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
+ %r = spirv.UDot %a, %b {volatile = #spirv.decoration<Volatile>}: (vector<4xi8>, vector<4xi8>) -> i32
+ return %r : i32
+}
+
+// -----
+
func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
// expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
%r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16
More information about the Mlir-commits mailing list