[llvm] AMDGPU: Support intrinsic selection for gfx1250 wmma instructions (PR #148957)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 15 13:51:24 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Changpeng Fang (changpeng)

<details>
<summary>Changes</summary>



---

Patch is 375.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148957.diff


20 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+169) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPU.td (+14) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+12) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+77) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+93) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h (+9) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+26-3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+39) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td (+4) 
- (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+38) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+2) 
- (modified) llvm/lib/Target/AMDGPU/SISchedule.td (+21) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+67-22) 
- (modified) llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll (+240) 
- (added) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.gfx1250.w32.ll (+840) 
- (added) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imm.gfx1250.w32.ll (+2238) 
- (added) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imod.gfx1250.w32.ll (+1993) 
- (added) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.index.gfx1250.w32.ll (+690) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 962693003349e..acd10af1709ea 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2919,6 +2919,20 @@ def int_amdgcn_permlanex16_var : ClangBuiltin<"__builtin_amdgcn_permlanex16_var"
 // the form: D = A * B + C.
 // A is sparse matrix, half the size of B, and is expanded using sparsity index.
 
+class AMDGPUSWmmacIntrinsicIdxReuse<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
+  Intrinsic<
+    [CD],               // %D
+    [
+      A,                // %A
+      B,                // %B
+      LLVMMatchType<0>, // %C
+      Index,            // %Sparsity index for A
+      llvm_i1_ty,       // matrix_a_reuse
+      llvm_i1_ty,       // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, IntrWillReturn, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]
+>;
+
 class AMDGPUSWmmacIntrinsicIdx<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
   Intrinsic<
     [CD],               // %D
@@ -3602,6 +3616,161 @@ def int_amdgcn_fdiv_fast : DefaultAttrsIntrinsic<
   [IntrNoMem, IntrSpeculatable]
 >;
 
+// WMMA intrinsics.
+class AMDGPUWmmaIntrinsicModsAB<LLVMType AB, LLVMType CD> :
+  Intrinsic<
+    [CD], // %D
+    [
+      llvm_i1_ty,       // %A_mod: 0 -- none, 1 -- neg
+      AB,               // %A
+      llvm_i1_ty,       // %B_mod: 0 -- none, 1 -- neg
+      LLVMMatchType<1>, // %B
+      LLVMMatchType<0>,               // %C
+      llvm_i1_ty,       // matrix_a_reuse
+      llvm_i1_ty,       // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<6>>,
+     IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsC<LLVMType AB, LLVMType CD> :
+  Intrinsic<
+    [CD], // %D
+    [
+      AB,               // %A
+      LLVMMatchType<1>, // %B
+      llvm_i16_ty,      // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+      LLVMMatchType<0>,               // %C
+      llvm_i1_ty,       // matrix_a_reuse
+      llvm_i1_ty,       // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>,
+     IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicF4ModsC<LLVMType A, LLVMType B, LLVMType CD> :
+  Intrinsic<
+    [CD], // %D
+    [
+      A,                // %A
+      B,                // %B
+      llvm_i16_ty,      // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+      LLVMMatchType<0>,               // %C
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<2>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsAll<LLVMType AB, LLVMType CD> :
+  Intrinsic<
+    [CD], // %D
+    [
+      llvm_i1_ty,       // %A_mod: 0 -- none, 1 -- neg
+      AB,               // %A
+      llvm_i1_ty,       // %B_mod: 0 -- none, 1 -- neg
+      LLVMMatchType<1>, // %B
+      llvm_i16_ty,      // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+      LLVMMatchType<0>,               // %C
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsAllReuse<LLVMType AB, LLVMType CD> :
+  Intrinsic<
+    [CD], // %D
+    [
+      llvm_i1_ty,       // %A_mod: 0 -- none, 1 -- neg
+      AB,               // %A
+      llvm_i1_ty,       // %B_mod: 0 -- none, 1 -- neg
+      LLVMMatchType<1>, // %B
+      llvm_i16_ty,      // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+      LLVMMatchType<0>,               // %C
+      llvm_i1_ty,       // matrix_a_reuse
+      llvm_i1_ty,       // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>,
+     IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+// D and C are of different types.
+class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
+  Intrinsic<
+    [DstTy],     // %D
+    [
+      llvm_i1_ty,        // %A_mod: 0 -- none, 1 -- neg
+      AB,                // %A
+      llvm_i1_ty,        // %B_mod: 0 -- none, 1 -- neg
+      LLVMMatchType<1>,  // %B
+      llvm_i16_ty,       // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+      C,                 // %C
+      llvm_i1_ty,        // matrix_a_reuse
+      llvm_i1_ty,        // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>,
+     IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
+def int_amdgcn_wmma_f64_16x16x4_f64       : AMDGPUWmmaIntrinsicModsAll<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x4_f32       : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16_16x16x32_bf16    : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_fp8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_fp8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_bf8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_bf8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_fp8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_fp8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_bf8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_bf8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_i32_16x16x64_iu8      : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_32x16x128_f4       : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
+}
+
+class AMDGPUSWmmacIntrinsicABIdx<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
+  Intrinsic<
+    [CD], // %D
+    [
+      llvm_i1_ty,       // %A_mod:  0 - none, 1 - neg
+      A,                // %A
+      llvm_i1_ty,       // %B_mod:  0 - none, 1 - neg
+      B,                // %B
+      LLVMMatchType<0>, // %C
+      Index,            // %Sparsity index for A
+      llvm_i1_ty,       // matrix_a_reuse
+      llvm_i1_ty,       // matrix_b_reuse
+    ],
+    [IntrNoMem, IntrConvergent, IntrWillReturn, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>]
+>;
+
+defset list<Intrinsic> AMDGPUSWMMACIntrinsicsGFX1250 = {
+def int_amdgcn_swmmac_f32_16x16x64_f16      : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x64_bf16     : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x64_f16      : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_bf16_16x16x64_bf16    : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_bf16f32_16x16x64_bf16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_fp8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_fp8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_bf8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_bf8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_fp8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_fp8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_bf8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_bf8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_i32_16x16x128_iu8     : AMDGPUSWmmacIntrinsicABIdx<llvm_anyint_ty, llvm_anyint_ty, llvm_anyint_ty, llvm_anyint_ty>;
+}
+
+
 class AMDGPUTensorLoadStore:
   Intrinsic<
     [],
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 04ec1716478e4..712dcadf2d317 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -844,6 +844,12 @@ def FeatureCvtFP8VOP1Bug : SubtargetFeature<"cvt-fp8-vop1-bug",
   [FeatureFP8ConversionInsts]
 >;
 
+def FeatureWMMA128bInsts : SubtargetFeature<"wmma-128b-insts",
+  "HasWMMA128bInsts",
+  "true",
+  "Has WMMA instructions where A and B matrices do not have duplicated data"
+>;
+
 def FeaturePkFmacF16Inst : SubtargetFeature<"pk-fmac-f16-inst",
   "HasPkFmacF16Inst",
   "true",
@@ -1925,6 +1931,7 @@ def FeatureISAVersion12 : FeatureSet<
    FeatureImageInsts,
    FeatureExtendedImageInsts,
    FeatureFP8ConversionInsts,
+   FeatureWMMA128bInsts,
    FeatureIEEEMinimumMaximumInsts,
    FeaturePackedTID,
    FeatureVcmpxPermlaneHazard,
@@ -2291,6 +2298,10 @@ def isGFX11Plus :
   Predicate<"Subtarget->getGeneration() >= AMDGPUSubtarget::GFX11">,
   AssemblerPredicate<(all_of FeatureGFX11Insts)>;
 
+def isGFX11PlusNot12_50 :
+  Predicate<"Subtarget->getGeneration() >= AMDGPUSubtarget::GFX11 && !Subtarget->hasGFX1250Insts()">,
+  AssemblerPredicate<(all_of FeatureGFX11Insts, (not FeatureGFX1250Insts))>;
+
 def isGFX12Only :
   Predicate<"Subtarget->getGeneration() == AMDGPUSubtarget::GFX12">,
   AssemblerPredicate<(all_of FeatureGFX12Insts)>;
@@ -2616,6 +2627,9 @@ def HasFP8Insts : Predicate<"Subtarget->hasFP8Insts()">,
 def HasFP8ConversionInsts : Predicate<"Subtarget->hasFP8ConversionInsts()">,
   AssemblerPredicate<(all_of FeatureFP8ConversionInsts)>;
 
+def HasWMMA128bInsts : Predicate<"Subtarget->hasWMMA128bInsts()">,
+  AssemblerPredicate<(all_of FeatureWMMA128bInsts)>;
+
 def HasFP8E5M3Insts : Predicate<"Subtarget->hasFP8E5M3Insts()">,
   AssemblerPredicate<(all_of FeatureFP8E5M3Insts)>;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..7b5d4077e85f3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -55,6 +55,14 @@ def gi_vop3pmodsneg :
     GIComplexOperandMatcher<s32, "selectVOP3PModsNeg">,
     GIComplexPatternEquiv<VOP3PModsNeg>;
 
+def gi_vop3pmodsnegs :
+    GIComplexOperandMatcher<s32, "selectVOP3PModsNegs">,
+    GIComplexPatternEquiv<VOP3PModsNegs>;
+
+def gi_dotiuvop3pmodsnegabs :
+    GIComplexOperandMatcher<s32, "selectVOP3PModsNegAbs">,
+    GIComplexPatternEquiv<VOP3PModsNegAbs>;
+
 def gi_wmmaopselvop3pmods :
     GIComplexOperandMatcher<s32, "selectWMMAOpSelVOP3PMods">,
     GIComplexPatternEquiv<WMMAOpSelVOP3PMods>;
@@ -83,6 +91,10 @@ def gi_swmmacindex16 :
     GIComplexOperandMatcher<s32, "selectSWMMACIndex16">,
     GIComplexPatternEquiv<SWMMACIndex16>;
 
+def gi_swmmacindex32 :
+    GIComplexOperandMatcher<s64, "selectSWMMACIndex32">,
+    GIComplexPatternEquiv<SWMMACIndex32>;
+
 def gi_vop3opselmods :
     GIComplexOperandMatcher<s32, "selectVOP3OpSelMods">,
     GIComplexPatternEquiv<VOP3OpSelMods>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 620eac428c084..25672a52345cb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3273,6 +3273,7 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsDOT(SDValue In, SDValue &Src,
   return SelectVOP3PMods(In, Src, SrcMods, true);
 }
 
+// Select neg_lo from the i1 immediate operand.
 bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
   const ConstantSDNode *C = cast<ConstantSDNode>(In);
   // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
@@ -3288,6 +3289,47 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
   return true;
 }
 
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies
+// to matrix's even k elements, and neg_hi applies to matrix's odd k elements.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegs(SDValue In, SDValue &Src) const {
+  const ConstantSDNode *C = cast<ConstantSDNode>(In);
+  // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
+  // 1 sets both neg_lo and neg_hi for that operand, 0 leaves them cleared.
+  assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value");
+
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  unsigned SrcSign = C->getZExtValue();
+  if (SrcSign == 1)
+    Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+
+  Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operand.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const {
+  const ConstantSDNode *C = cast<ConstantSDNode>(In);
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  unsigned SrcMod = C->getZExtValue();
+  switch (SrcMod) {
+  default: // Any other value will be silently ignored (considered as 0).
+    break;
+  case 1:
+    Mods ^= SISrcMods::NEG;
+    break;
+  case 2:
+    Mods ^= SISrcMods::ABS;
+    break;
+  case 3:
+    Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+    break;
+  }
+
+  Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
 bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
                                                   SDValue &Src) const {
   const ConstantSDNode *C = cast<ConstantSDNode>(In);
@@ -3639,6 +3681,41 @@ bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src,
   return true;
 }
 
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex32(SDValue In, SDValue &Src,
+                                             SDValue &IndexKey) const {
+  unsigned Key = 0;
+  Src = In;
+
+  SDValue InI32;
+
+  if (In.getOpcode() == ISD::ANY_EXTEND || In.getOpcode() == ISD::ZERO_EXTEND) {
+    const SDValue &ExtendSrc = In.getOperand(0);
+    if (ExtendSrc.getValueSizeInBits() == 32)
+      InI32 = ExtendSrc;
+  } else if (In->getOpcode() == ISD::BITCAST) {
+    const SDValue &CastSrc = In.getOperand(0);
+    if (CastSrc.getOpcode() == ISD::BUILD_VECTOR &&
+        CastSrc.getOperand(0).getValueSizeInBits() == 32) {
+      ConstantSDNode *Zero = dyn_cast<ConstantSDNode>(CastSrc.getOperand(1));
+      if (Zero && Zero->getZExtValue() == 0)
+        InI32 = CastSrc.getOperand(0);
+    }
+  }
+
+  if (InI32 && InI32.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+    const SDValue &ExtractVecEltSrc = InI32.getOperand(0);
+    ConstantSDNode *EltIdx = dyn_cast<ConstantSDNode>(InI32.getOperand(1));
+    if (ExtractVecEltSrc.getValueSizeInBits() == 64 && EltIdx &&
+        EltIdx->getZExtValue() == 1) {
+      Key = 1;
+      Src = ExtractVecEltSrc;
+    }
+  }
+
+  IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+  return true;
+}
+
 bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src,
                                          SDValue &SrcMods) const {
   Src = In;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
index f3b9364fdb92b..9967f46e085e4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -222,6 +222,8 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
   bool SelectVOP3PModsDOT(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
   bool SelectVOP3PModsNeg(SDValue In, SDValue &Src) const;
+  bool SelectVOP3PModsNegs(SDValue In, SDValue &Src) const;
+  bool SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const;
   bool SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const;
 
   bool SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
@@ -233,6 +235,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
 
   bool SelectSWMMACIndex8(SDValue In, SDValue &Src, SDValue &IndexKey) const;
   bool SelectSWMMACIndex16(SDValue In, SDValue &Src, SDValue &IndexKey) const;
+  bool SelectSWMMACIndex32(SDValue In, SDValue &Src, SDValue &IndexKey) const;
 
   bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index ea79c57080faa..1a63c48e3666c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -3513,6 +3513,25 @@ static Register matchZeroExtendFromS32(MachineRegisterInfo &MRI, Register Reg) {
   return Register();
 }
 
+Register AMDGPUInstructionSelector::matchAnyExtendFromS32(Register Reg) const {
+  Register AnyExtSrc;
+  if (mi_match(Reg, *MRI, m_GAnyExt(m_Reg(AnyExtSrc))))
+    return MRI->getType(AnyExtSrc) == LLT::scalar(32) ? AnyExtSrc : Register();
+
+  // Match legalized form %anyext = G_MERGE_VALUES (s32 %x), (s32 G_IMPLICIT_DEF)
+  const MachineInstr *Def = getDefIgnoringCopies(Reg, *MRI);
+  if (Def->getOpcode() != AMDGPU::G_MERGE_VALUES)
+    return Register();
+
+  assert(Def->getNumOperands() == 3 &&
+         MRI->getType(Def->getOperand(0).getReg()) == LLT::scalar(64));
+
+  if (mi_match(Def->getOperand(2).getReg(), *MRI, m_GImplicitDef()))
+    return Def->getOperand(1).getReg();
+
+  return Register();
+}
+
 bool AMDGPUInstructionSelector::selectGlobalLoadLds(MachineInstr &MI) const{
   if (!Subtarget->hasVMemToLDSLoad())
     return false;
@@ -4904,6 +4923,7 @@ AMDGPUInstructionSelector::selectVOP3PModsDOT(MachineOperand &Root) const {
   return selectVOP3PRetHelper(Root, true);
 }
 
+// Select neg_lo from the i1 immediate operand.
 InstructionSelector::ComplexRendererFns
 AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
   // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
@@ -4919,6 +4939,50 @@ AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
   }};
 }
 
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies
+// to matrix's even k elements, and neg_hi applies to matrix's odd k elements.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegs(MachineOperand &Root) const {
+  // Literal i1 va...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/148957


More information about the llvm-commits mailing list