[llvm] 1a7cf7a - [Verifier] Verify sizes of matrix.multiply operands and specified shape.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 4 12:52:02 PDT 2023


Author: Florian Hahn
Date: 2023-04-04T20:51:43+01:00
New Revision: 1a7cf7a182d61ae81b00aaf5f30f3b471fae3668

URL: https://github.com/llvm/llvm-project/commit/1a7cf7a182d61ae81b00aaf5f30f3b471fae3668
DIFF: https://github.com/llvm/llvm-project/commit/1a7cf7a182d61ae81b00aaf5f30f3b471fae3668.diff

LOG: [Verifier] Verify sizes of matrix.multiply operands and specified shape.

Extend the verifier to check that the sizes of the matrix operands of
matrix.multiply match the shape specified by the numeric arguments.

Reviewed By: thegameg

Differential Revision: https://reviews.llvm.org/D147466

Added: 
    

Modified: 
    llvm/lib/IR/Verifier.cpp
    llvm/test/Verifier/matrix-intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5cc90e0de938f..90dd6efdf54a5 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5638,15 +5638,28 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     Type *Op0ElemTy = nullptr;
     Type *Op1ElemTy = nullptr;
     switch (ID) {
-    case Intrinsic::matrix_multiply:
+    case Intrinsic::matrix_multiply: {
       NumRows = cast<ConstantInt>(Call.getArgOperand(2));
+      ConstantInt *N = cast<ConstantInt>(Call.getArgOperand(3));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(4));
+      Check(cast<FixedVectorType>(Call.getArgOperand(0)->getType())
+                    ->getNumElements() ==
+                NumRows->getZExtValue() * N->getZExtValue(),
+            "First argument of a matrix operation does not match specified "
+            "shape!");
+      Check(cast<FixedVectorType>(Call.getArgOperand(1)->getType())
+                    ->getNumElements() ==
+                N->getZExtValue() * NumColumns->getZExtValue(),
+            "Second argument of a matrix operation does not match specified "
+            "shape!");
+
       ResultTy = cast<VectorType>(Call.getType());
       Op0ElemTy =
           cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
       Op1ElemTy =
           cast<VectorType>(Call.getArgOperand(1)->getType())->getElementType();
       break;
+    }
     case Intrinsic::matrix_transpose:
       NumRows = cast<ConstantInt>(Call.getArgOperand(1));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(2));

diff  --git a/llvm/test/Verifier/matrix-intrinsics.ll b/llvm/test/Verifier/matrix-intrinsics.ll
index 60b9518bb1172..b6d5ad9a3cc49 100644
--- a/llvm/test/Verifier/matrix-intrinsics.ll
+++ b/llvm/test/Verifier/matrix-intrinsics.ll
@@ -1,4 +1,4 @@
-; RUN: not llvm-as -opaque-pointers < %s -o /dev/null 2>&1 | FileCheck %s
+; RUN: not llvm-as < %s -o /dev/null 2>&1 | FileCheck %s
 
 define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
 ; CHECK: assembly parsed, but does not verify as correct!
@@ -20,17 +20,19 @@ define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
 }
 
 define <4 x float> @multiply(<4 x float> %m, i32 %arg) {
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: Second argument of a matrix operation does not match specified shape!
 ; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
 ; CHECK-NEXT: immarg operand has non-immediate parameter
 ; CHECK-NEXT: i32 %arg
-; CHECK-NEXT:   %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+; CHECK-NEXT:   %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
   %result.0 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %m, i32 0, i32 0, i32 0)
   %result.1 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.0, <4 x float> %m, i32 3, i32 2, i32 2)
   %result.2 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.1, <4 x float> %m, i32 2, i32 2, i32 1)
-  %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
-  ret <4 x float> %result.3
+  %result.3 = call <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 2, i32 2, i32 2)
+  %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+  ret <4 x float> %result.4
 }
 
 define <4 x float> @column.major_load(ptr %m, ptr %n, i32 %arg) {
@@ -136,3 +138,4 @@ declare <4 x float> @llvm.matrix.multiply.v4f32.v4i32.v4f32(<4 x i32>, <4 x floa
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4i32(<4 x float>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
+declare <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)


        


More information about the llvm-commits mailing list