[Mlir-commits] [mlir] 0d691ac - [mlir][spirv] Fix integer dot product format attr validation

Jakub Kuderski llvmlistbot at llvm.org
Tue Dec 6 20:33:54 PST 2022


Author: Jakub Kuderski
Date: 2022-12-06T23:29:42-05:00
New Revision: 0d691ac4472b5a2e9f9ed3d4a7b91460791efb90

URL: https://github.com/llvm/llvm-project/commit/0d691ac4472b5a2e9f9ed3d4a7b91460791efb90
DIFF: https://github.com/llvm/llvm-project/commit/0d691ac4472b5a2e9f9ed3d4a7b91460791efb90.diff

LOG: [mlir][spirv] Fix integer dot product format attr validation

Do not allow the packed vector `format` attribute when the operands are vectors (i.e., non-scalar operands).

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139495

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index a9e67ad15a956..52ad8ad5fe7c7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4806,7 +4806,9 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   if (op->getOperand(1).getType() != factorTy)
     return op->emitOpError("requires the same type for both vector operands");
 
+  unsigned expectedNumAttrs = 0;
   if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+    ++expectedNumAttrs;
     auto packedVectorFormat =
         op->getAttr(kPackedVectorFormatAttrName)
             .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
@@ -4816,15 +4818,20 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
 
     assert(packedVectorFormat.getValue() ==
                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
-           "unknown Packed Vector format");
+           "Unknown Packed Vector Format");
     if (intTy.getWidth() != 32)
       return op->emitOpError(
           llvm::formatv("with specified Packed Vector Format ({0}) requires "
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
+  } else {
+    if (op->hasAttr(kPackedVectorFormatAttrName))
+      return op->emitOpError(llvm::formatv(
+          "with invalid format attribute for vector operands of type '{0}'",
+          factorTy));
   }
 
-  if (op->getAttrs().size() > 1)
+  if (op->getAttrs().size() > expectedNumAttrs)
     return op->emitError(
         "op only supports the 'format' #spirv.packed_vector_format attribute");
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
index 9e9ae67e0d301..8d3c3b85b4887 100644
--- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -49,6 +49,14 @@ func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
   %r = spirv.SDot %a, %b : (i32, i64) -> i32
   return %r : i32
 }
+// -----
+
+func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}:
+        (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
 
 // -----
 
@@ -61,6 +69,14 @@ func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 {
 
 // -----
 
+func.func @udot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
+  %r = spirv.UDot %a, %b {volatile = #spirv.decoration<Volatile>}: (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
+
+// -----
+
 func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
   %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16


        


More information about the Mlir-commits mailing list