[clang] [llvm] AMDGPU: Shrink used number of registers for mfma scale based on format (PR #117047)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 09:00:30 PST 2024


https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/117047

>From 975c6c066c577171bf51b4dc5b09384ec2d98540 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Wed, 17 Jul 2024 10:14:05 +0400
Subject: [PATCH] AMDGPU: Shrink used number of registers for mfma scale based
 on format

Currently the builtins assume you are using an 8-bit format that requires
an 8 element vector. We can shrink the number of registers if the format
requires 4 or 6.
---
 .../CodeGenOpenCL/builtins-amdgcn-mfma.cl     |   6 +-
 .../AMDGPU/AMDGPUInstCombineIntrinsic.cpp     |  56 ++++
 .../InstCombine/AMDGPU/mfma-scale.ll          | 312 ++++++++++++++++++
 3 files changed, 372 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/AMDGPU/mfma-scale.ll

diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
index ea9bdcdc211623..b4c9aa64c4df6f 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
@@ -433,14 +433,16 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
 }
 
 // CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
-// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
+// CHECK-GFX950: [[EXTRACT_A:%.+]] = shufflevector <8 x i32> %a, <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> [[EXTRACT_A]], <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
 void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
 {
   *out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
 }
 
 // CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
-// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
+// CHECK-GFX950: [[EXTRACT_A:%.+]] = shufflevector <8 x i32> %a, <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> [[EXTRACT_A]], <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
 void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
 {
   *out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 30b9e3ea11ef4b..087de1bed86f76 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1260,6 +1260,62 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     }
     return std::nullopt;
   }
+  case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+  case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
+    Value *Src0 = II.getArgOperand(0);
+    Value *Src1 = II.getArgOperand(1);
+    uint64_t CBSZ = cast<ConstantInt>(II.getArgOperand(3))->getZExtValue();
+    uint64_t BLGP = cast<ConstantInt>(II.getArgOperand(4))->getZExtValue();
+    auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+    auto getFormatNumRegs = [](unsigned FormatVal) {
+      switch (FormatVal) {
+      case AMDGPU::MFMAScaleFormats::FP6_E2M3:
+      case AMDGPU::MFMAScaleFormats::FP6_E3M2:
+        return 6u;
+      case AMDGPU::MFMAScaleFormats::FP4_E2M1:
+        return 4u;
+      case AMDGPU::MFMAScaleFormats::FP8_E4M3:
+      case AMDGPU::MFMAScaleFormats::FP8_E5M2:
+        return 8u;
+      default:
+        llvm_unreachable("invalid format value");
+      }
+    };
+
+    bool MadeChange = false;
+    unsigned Src0NumElts = getFormatNumRegs(CBSZ);
+    unsigned Src1NumElts = getFormatNumRegs(BLGP);
+
+    // The format fixes how many elements are needed (FP8: 8, FP6: 6, FP4: 4);
+    // when a source vector is wider, keep only its leading elements.
+    if (Src0Ty->getNumElements() > Src0NumElts) {
+      Src0 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (Src1Ty->getNumElements() > Src1NumElts) {
+      Src1 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (!MadeChange)
+      return std::nullopt;
+
+    SmallVector<Value *, 10> Args(II.args());
+    Args[0] = Src0;
+    Args[1] = Src1;
+
+    CallInst *NewII = IC.Builder.CreateIntrinsic(
+        IID, {Src0->getType(), Src1->getType()}, Args, &II);
+    NewII->takeName(&II);
+    return IC.replaceInstUsesWith(II, NewII);
+  }
   }
   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/mfma-scale.ll b/llvm/test/Transforms/InstCombine/AMDGPU/mfma-scale.ll
new file mode 100644
index 00000000000000..709f143a4745a5
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/mfma-scale.ll
@@ -0,0 +1,312 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=instcombine < %s | FileCheck %s
+
+; --------------------------------------------------------------------
+; Incorrect signature for format cases (IR vector too large) 16x16x128
+; --------------------------------------------------------------------
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_fp6(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> [[ARG0]], <6 x i32> [[TMP1]], <4 x float> [[ARG2]], i32 0, i32 2, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 0, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp8(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp8(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> [[TMP1]], <8 x i32> [[ARG1]], <4 x float> [[ARG2]], i32 2, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 2, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp6(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> [[TMP1]], <6 x i32> [[TMP2]], <4 x float> [[ARG2]], i32 2, i32 2, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 2, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp6__0_scale(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp6__v8i32_fp6__0_scale(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> [[TMP1]], <6 x i32> [[TMP2]], <4 x float> [[ARG2]], i32 2, i32 2, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 2, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 0, i32 0, i32 0)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_fp4(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> [[ARG0]], <4 x i32> [[TMP1]], <4 x float> [[ARG2]], i32 0, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 0, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp8(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp8(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> [[TMP1]], <8 x i32> [[ARG1]], <4 x float> [[ARG2]], i32 4, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 4, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v6i32_fp4(<8 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v6i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <6 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i32> [[ARG1]], <6 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> [[ARG0]], <4 x i32> [[TMP1]], <4 x float> [[ARG2]], i32 0, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2,
+  i32 0, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v6i32_fp4__v8i32_fp8(<6 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v6i32_fp4__v8i32_fp8(
+; CHECK-SAME: <6 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i32> [[ARG0]], <6 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> [[TMP1]], <8 x i32> [[ARG1]], <4 x float> [[ARG2]], i32 4, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 4, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp4(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]], <4 x float> [[ARG2]], i32 4, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 4, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp4__0_scale(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2) {
+; CHECK-LABEL: define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v8i32_fp4__0_scale(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]], <4 x float> [[ARG2]], i32 4, i32 4, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 4, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 0, i32 0, i32 0)
+  ret <4 x float> %result
+}
+
+define <4 x float> @test_flags_shrink_src0(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <4 x float> @test_flags_shrink_src0(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <4 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan nsz <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> [[ARG0]], <6 x i32> [[TMP1]], <4 x float> [[ARG2]], i32 0, i32 2, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <4 x float> [[RESULT]]
+;
+  %result = call nnan nsz <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2,
+  i32 0, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <4 x float> %result
+}
+
+; --------------------------------------------------------------------
+; Incorrect signature for format cases (IR vector too large) 32x32x64
+; --------------------------------------------------------------------
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v8i32_fp6(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> [[ARG0]], <6 x i32> [[TMP1]], <16 x float> [[ARG2]], i32 0, i32 2, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 0, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp8(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp8(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> [[TMP1]], <8 x i32> [[ARG1]], <16 x float> [[ARG2]], i32 2, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 2, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp6(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> [[TMP1]], <6 x i32> [[TMP2]], <16 x float> [[ARG2]], i32 2, i32 2, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 2, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp6__0_scale(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp6__v8i32_fp6__0_scale(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> [[TMP1]], <6 x i32> [[TMP2]], <16 x float> [[ARG2]], i32 2, i32 2, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 2, ; cbsz
+  i32 2, ; blgp
+  i32 0, i32 0, i32 0, i32 0)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v8i32_fp4(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v8i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> [[ARG0]], <4 x i32> [[TMP1]], <16 x float> [[ARG2]], i32 0, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 0, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp8(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp8(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> [[TMP1]], <8 x i32> [[ARG1]], <16 x float> [[ARG2]], i32 4, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 4, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v6i32_fp4(<8 x i32> %arg0, <6 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8__v6i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <6 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i32> [[ARG1]], <6 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> [[ARG0]], <4 x i32> [[TMP1]], <16 x float> [[ARG2]], i32 0, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %arg0, <6 x i32> %arg1, <16 x float> %arg2,
+  i32 0, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v6i32_fp4__v8i32_fp8(<6 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v6i32_fp4__v8i32_fp8(
+; CHECK-SAME: <6 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i32> [[ARG0]], <6 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> [[TMP1]], <8 x i32> [[ARG1]], <16 x float> [[ARG2]], i32 4, i32 0, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 4, ; cbsz
+  i32 0, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp4(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp4(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]], i32 [[SCALE0:%.*]], i32 [[SCALE1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]], <16 x float> [[ARG2]], i32 4, i32 4, i32 0, i32 [[SCALE0]], i32 0, i32 [[SCALE1]])
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 4, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 %scale0, i32 0, i32 %scale1)
+  ret <16 x float> %result
+}
+
+define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp4__0_scale(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2) {
+; CHECK-LABEL: define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp4__v8i32_fp4__0_scale(
+; CHECK-SAME: <8 x i32> [[ARG0:%.*]], <8 x i32> [[ARG1:%.*]], <16 x float> [[ARG2:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[ARG0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[ARG1]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[RESULT:%.*]] = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]], <16 x float> [[ARG2]], i32 4, i32 4, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT:    ret <16 x float> [[RESULT]]
+;
+  %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2,
+  i32 4, ; cbsz
+  i32 4, ; blgp
+  i32 0, i32 0, i32 0, i32 0)
+  ret <16 x float> %result
+}



More information about the llvm-commits mailing list