[llvm-branch-commits] [llvm] AMDGPU: Add basic verification for mfma scale intrinsics (PR #117048)

Matt Arsenault via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 20 12:56:16 PST 2024


https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/117048

Verify the format is valid and the type is one of the expected
i32 vectors. Verify the used vector types at least cover the
requirements of the corresponding format operand.

>From 8b48d1d59f79c456fc09c34100f619d978625be1 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 15 Jul 2024 17:23:00 +0400
Subject: [PATCH] AMDGPU: Add basic verification for mfma scale intrinsics

Verify the format is valid and the type is one of the expected
i32 vectors. Verify the used vector types at least cover the
requirements of the corresponding format operand.
---
 llvm/include/llvm/IR/IntrinsicsAMDGPU.td |  10 +-
 llvm/lib/IR/Verifier.cpp                 |  49 +++++
 llvm/test/Verifier/AMDGPU/mfma-scale.ll  | 230 +++++++++++++++++++++++
 3 files changed, 283 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Verifier/AMDGPU/mfma-scale.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 3a5fc86183ca0e..ee7ea8343eacbf 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2973,12 +2973,10 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
 // blgp.
 //
 // These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
-// and <4 x i32> for f4 formats. If the format control bits imply a
-// smaller type than used, the high elements will be truncated.
-//
-// If the format control bits imply a larger type than used, the high
-// elements are padded with undef.
-
+// and <4 x i32> for f4 formats. It is invalid to use a format that
+// requires more registers than the corresponding vector type provides
+// (e.g. it is illegal to use <6 x i32> in operand 0 if cbsz specifies
+// an f8 format that requires 8 registers).
 class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
   DefaultAttrsIntrinsic<[DestTy],
             [llvm_anyvector_ty, llvm_anyvector_ty, DestTy,
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5c0ccf734cccbc..32b50e199ae8fc 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6383,6 +6383,55 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
         "llvm.amdgcn.s.prefetch.data only supports global or constant memory");
     break;
   }
+  case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+  case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
+    Value *Src0 = Call.getArgOperand(0);
+    Value *Src1 = Call.getArgOperand(1);
+    // cbsz (operand 3) is Src0's format; blgp (operand 4) is Src1's.
+    uint64_t CBSZ = cast<ConstantInt>(Call.getArgOperand(3))->getZExtValue();
+    uint64_t BLGP = cast<ConstantInt>(Call.getArgOperand(4))->getZExtValue();
+    Check(CBSZ <= 4, "invalid value for cbsz format", &Call,
+          Call.getArgOperand(3));
+    Check(BLGP <= 4, "invalid value for blgp format", &Call,
+          Call.getArgOperand(4));
+    // Minimum number of i32 elements a source vector needs for a format.
+    // AMDGPU::MFMAScaleFormats values
+    auto getFormatNumRegs = [](unsigned FormatVal) {
+      switch (FormatVal) {
+      case 0:
+      case 1:
+        return 8u;
+      case 2:
+      case 3:
+        return 6u;
+      case 4:
+        return 4u;
+      default:
+        llvm_unreachable("invalid format value");
+      }
+    };
+    // Accept only fixed vectors of i32 with 4, 6 or 8 elements.
+    auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+      if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+        return false;
+      unsigned NumElts = Ty->getNumElements();
+      return NumElts == 4 || NumElts == 6 || NumElts == 8;
+    };
+
+    auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+    Check(isValidSrcASrcBVector(Src0Ty),
+          "operand 0 must be 4, 6 or 8 element i32 vector", &Call, Src0);
+    Check(isValidSrcASrcBVector(Src1Ty),
+          "operand 1 must be 4, 6 or 8 element i32 vector", &Call, Src1);
+    // Permit excess registers for the format. Each failure reports the
+    // source operand together with its own format operand (cbsz or blgp).
+    Check(Src0Ty->getNumElements() >= getFormatNumRegs(CBSZ),
+          "invalid vector type for format", &Call, Src0, Call.getArgOperand(3));
+    Check(Src1Ty->getNumElements() >= getFormatNumRegs(BLGP),
+          "invalid vector type for format", &Call, Src1, Call.getArgOperand(4));
+    break;
+  }
   case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
   case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
     Value *V = Call.getArgOperand(0);
diff --git a/llvm/test/Verifier/AMDGPU/mfma-scale.ll b/llvm/test/Verifier/AMDGPU/mfma-scale.ll
new file mode 100644
index 00000000000000..1e3e8856df3d10
--- /dev/null
+++ b/llvm/test/Verifier/AMDGPU/mfma-scale.ll
@@ -0,0 +1,230 @@
+; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
+
+; --------------------------------------------------------------------
+; Wrong mangled types
+; --------------------------------------------------------------------
+
+; CHECK: operand 0 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i64> %arg0
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v4i64_fp8__v8i32_fp8(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: operand 1 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i64> %arg1
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8v4i64_fp8(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: operand 0 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i64> %arg0
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v4i64_fp8__v8i32_fp8(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; CHECK: operand 1 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i64> %arg1
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8v4i64_fp8(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; --------------------------------------------------------------------
+; Impossible vector types
+; --------------------------------------------------------------------
+
+; CHECK: operand 0 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v5i32.v8i32(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <5 x i32> %arg0
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v5i32_fp4__v8i32_fp4(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v5i32.v8i32(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, ; .v5i32 suffix matches %arg0
+                                                                                      i32 4, ; cbsz
+                                                                                      i32 4, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: operand 1 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v5i32(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <5 x i32> %arg1
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v5i32_fp4(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v5i32(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 4, ; cbsz
+                                                                                      i32 4, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: operand 0 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v7i32.v8i32(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <7 x i32> %arg0
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v7i32_fp4__v8i32_fp4(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v7i32.v8i32(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, ; .v7i32 suffix matches %arg0
+                                                                                      i32 4, ; cbsz
+                                                                                      i32 4, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: operand 1 must be 4, 6 or 8 element i32 vector
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v7i32(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <7 x i32> %arg1
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v7i32_fp4(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v7i32(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 4, ; cbsz
+                                                                                      i32 4, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; --------------------------------------------------------------------
+; Out of bounds format
+; --------------------------------------------------------------------
+
+; CHECK: invalid value for cbsz format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 9999, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: i32 9999
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_invalid0__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 9999, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid value for blgp format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 9999, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: i32 9999
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_invalid0(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 9999, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid value for cbsz format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 5, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: i32 5
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_invalid1__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 5, ; cbsz
+                                                                                      i32 2, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid value for blgp format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 5, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: i32 5
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i321_invalid(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 5, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid value for cbsz format
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 5, i32 5, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: i32 5
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_invalid__v8i32_invalid(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 5, ; cbsz
+                                                                                      i32 5, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; --------------------------------------------------------------------
+; Incorrect signature for format cases (IR vector too small)
+; --------------------------------------------------------------------
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i32> %arg0
+; CHECK-NEXT: i32 0
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v4i32_fp8__v8i32_fp8(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i32> %arg1
+; CHECK-NEXT: i32 0
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4__v8i32_fp8___v4i32_fp8(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i32> %arg0
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v4i32_fp8__v4i32_fp8(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <6 x i32> %arg0
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v6i32_fp8__v6i32_fp8(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <4 x i32> %arg0
+; CHECK-NEXT: i32 0
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v4i32_fp8__v4i32_fp8(<4 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+; CHECK: invalid vector type for format
+; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1)
+; CHECK-NEXT: <6 x i32> %arg0
+; CHECK-NEXT: i32 0
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v6i32_fp8__v6i32_fp8(<6 x i32> %arg0, <6 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <16 x float> %arg2,
+                                                                                      i32 0, ; cbsz
+                                                                                      i32 0, ; blgp
+                                                                                      i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}



More information about the llvm-branch-commits mailing list