[llvm] AMDGPU: Implement MC layer support for gfx1250 wmma instructions. (PR #148570)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 13 23:38:12 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mc

@llvm/pr-subscribers-backend-amdgpu

Author: Changpeng Fang (changpeng)

<details>
<summary>Changes</summary>



---

Patch is 331.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148570.diff


25 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPU.td (+10) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+12) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+76) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+92) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h (+9) 
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+198-3) 
- (modified) llvm/lib/Target/AMDGPU/CMakeLists.txt (+2-1) 
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+53-1) 
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h (+1) 
- (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+121) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+21) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp (+3-2) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+9-3) 
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+21) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+21) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+2) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp (+23) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+6) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+441-84) 
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+36-11) 
- (added) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+1739) 
- (added) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+490) 
- (added) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+1001) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 91ace4d2b7f16..3507d0fdefd5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -838,6 +838,12 @@ def FeatureCvtFP8VOP1Bug : SubtargetFeature<"cvt-fp8-vop1-bug",
   [FeatureFP8ConversionInsts]
 >;
 
+def FeatureWMMA128bInsts : SubtargetFeature<"wmma-128b-insts",
+  "HasWMMA128bInsts",
+  "true",
+  "Has WMMA instructions where A and B matrices do not have duplicated data"
+>;
+
 def FeaturePkFmacF16Inst : SubtargetFeature<"pk-fmac-f16-inst",
   "HasPkFmacF16Inst",
   "true",
@@ -1919,6 +1925,7 @@ def FeatureISAVersion12 : FeatureSet<
    FeatureImageInsts,
    FeatureExtendedImageInsts,
    FeatureFP8ConversionInsts,
+   FeatureWMMA128bInsts,
    FeatureIEEEMinimumMaximumInsts,
    FeaturePackedTID,
    FeatureVcmpxPermlaneHazard,
@@ -2602,6 +2609,9 @@ def HasFP8Insts : Predicate<"Subtarget->hasFP8Insts()">,
 def HasFP8ConversionInsts : Predicate<"Subtarget->hasFP8ConversionInsts()">,
   AssemblerPredicate<(all_of FeatureFP8ConversionInsts)>;
 
+def HasWMMA128bInsts : Predicate<"Subtarget->hasWMMA128bInsts()">,
+  AssemblerPredicate<(all_of FeatureWMMA128bInsts)>;
+
 def HasFP8E5M3Insts : Predicate<"Subtarget->hasFP8E5M3Insts()">,
   AssemblerPredicate<(all_of FeatureFP8E5M3Insts)>;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..7b5d4077e85f3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -55,6 +55,14 @@ def gi_vop3pmodsneg :
     GIComplexOperandMatcher<s32, "selectVOP3PModsNeg">,
     GIComplexPatternEquiv<VOP3PModsNeg>;
 
+def gi_vop3pmodsnegs :
+    GIComplexOperandMatcher<s32, "selectVOP3PModsNegs">,
+    GIComplexPatternEquiv<VOP3PModsNegs>;
+
+def gi_dotiuvop3pmodsnegabs :
+    GIComplexOperandMatcher<s32, "selectVOP3PModsNegAbs">,
+    GIComplexPatternEquiv<VOP3PModsNegAbs>;
+
 def gi_wmmaopselvop3pmods :
     GIComplexOperandMatcher<s32, "selectWMMAOpSelVOP3PMods">,
     GIComplexPatternEquiv<WMMAOpSelVOP3PMods>;
@@ -83,6 +91,10 @@ def gi_swmmacindex16 :
     GIComplexOperandMatcher<s32, "selectSWMMACIndex16">,
     GIComplexPatternEquiv<SWMMACIndex16>;
 
+def gi_swmmacindex32 :
+    GIComplexOperandMatcher<s64, "selectSWMMACIndex32">,
+    GIComplexPatternEquiv<SWMMACIndex32>;
+
 def gi_vop3opselmods :
     GIComplexOperandMatcher<s32, "selectVOP3OpSelMods">,
     GIComplexPatternEquiv<VOP3OpSelMods>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 202693b316122..7a8391b52ab0b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3260,6 +3260,47 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
   return true;
 }
 
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is specifically
+// for F16/BF16 operands in WMMA instructions, where neg_lo applies to matrix's even
+// k elements, and neg_hi applies to matrix's odd k elements.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegs(SDValue In, SDValue &Src) const {
+  const ConstantSDNode *C = cast<ConstantSDNode>(In);
+  // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
+  // 1 sets both neg_lo and neg_hi (NEG | NEG_HI), 0 leaves them clear.
+  assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value");
+
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  unsigned SrcSign = C->getZExtValue();
+  if (SrcSign == 1)
+    Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+
+  Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operand.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const {
+  const ConstantSDNode *C = cast<ConstantSDNode>(In);
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  unsigned SrcMod = C->getZExtValue();
+  switch (SrcMod) {
+  default: // Any other value will be silently ignored (considered as 0).
+    break;
+  case 1:
+    Mods ^= SISrcMods::NEG;
+    break;
+  case 2:
+    Mods ^= SISrcMods::ABS;
+    break;
+  case 3:
+    Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+    break;
+  }
+
+  Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
 bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
                                                   SDValue &Src) const {
   const ConstantSDNode *C = cast<ConstantSDNode>(In);
@@ -3611,6 +3652,41 @@ bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src,
   return true;
 }
 
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex32(SDValue In, SDValue &Src,
+                                             SDValue &IndexKey) const {
+  unsigned Key = 0;
+  Src = In;
+
+  SDValue InI32;
+
+  if (In.getOpcode() == ISD::ANY_EXTEND || In.getOpcode() == ISD::ZERO_EXTEND) {
+    const SDValue &ExtendSrc = In.getOperand(0);
+    if (ExtendSrc.getValueSizeInBits() == 32)
+      InI32 = ExtendSrc;
+  } else if (In->getOpcode() == ISD::BITCAST) {
+    const SDValue &CastSrc = In.getOperand(0);
+    if (CastSrc.getOpcode() == ISD::BUILD_VECTOR &&
+        CastSrc.getOperand(0).getValueSizeInBits() == 32) {
+      ConstantSDNode *Zero = dyn_cast<ConstantSDNode>(CastSrc.getOperand(1));
+      if (Zero && Zero->getZExtValue() == 0)
+        InI32 = CastSrc.getOperand(0);
+    }
+  }
+
+  if (InI32 && InI32.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+    const SDValue &ExtractVecEltSrc = InI32.getOperand(0);
+    ConstantSDNode *EltIdx = dyn_cast<ConstantSDNode>(InI32.getOperand(1));
+    if (ExtractVecEltSrc.getValueSizeInBits() == 64 && EltIdx &&
+        EltIdx->getZExtValue() == 1) {
+      Key = 1;
+      Src = ExtractVecEltSrc;
+    }
+  }
+
+  IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+  return true;
+}
+
 bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src,
                                          SDValue &SrcMods) const {
   Src = In;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
index f3b9364fdb92b..9967f46e085e4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -222,6 +222,8 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
   bool SelectVOP3PModsDOT(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
   bool SelectVOP3PModsNeg(SDValue In, SDValue &Src) const;
+  bool SelectVOP3PModsNegs(SDValue In, SDValue &Src) const;
+  bool SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const;
   bool SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const;
 
   bool SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
@@ -233,6 +235,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
 
   bool SelectSWMMACIndex8(SDValue In, SDValue &Src, SDValue &IndexKey) const;
   bool SelectSWMMACIndex16(SDValue In, SDValue &Src, SDValue &IndexKey) const;
+  bool SelectSWMMACIndex32(SDValue In, SDValue &Src, SDValue &IndexKey) const;
 
   bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index ea79c57080faa..b3952305d24a2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -3513,6 +3513,25 @@ static Register matchZeroExtendFromS32(MachineRegisterInfo &MRI, Register Reg) {
   return Register();
 }
 
+Register AMDGPUInstructionSelector::matchAnyExtendFromS32(Register Reg) const {
+  Register AnyExtSrc;
+  if (mi_match(Reg, *MRI, m_GAnyExt(m_Reg(AnyExtSrc))))
+    return MRI->getType(AnyExtSrc) == LLT::scalar(32) ? AnyExtSrc : Register();
+
+  // Match legalized form %zext = G_MERGE_VALUES (s32 %x), (s32 G_IMPLICIT_DEF)
+  const MachineInstr *Def = getDefIgnoringCopies(Reg, *MRI);
+  if (Def->getOpcode() != AMDGPU::G_MERGE_VALUES)
+    return Register();
+
+  assert(Def->getNumOperands() == 3 &&
+         MRI->getType(Def->getOperand(0).getReg()) == LLT::scalar(64));
+
+  if (mi_match(Def->getOperand(2).getReg(), *MRI, m_GImplicitDef()))
+    return Def->getOperand(1).getReg();
+
+  return Register();
+}
+
 bool AMDGPUInstructionSelector::selectGlobalLoadLds(MachineInstr &MI) const{
   if (!Subtarget->hasVMemToLDSLoad())
     return false;
@@ -4919,6 +4938,50 @@ AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
   }};
 }
 
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is specifically
+// for F16/BF16 operands in WMMA instructions, where neg_lo applies to matrix's even
+// k elements, and neg_hi applies to matrix's odd k elements.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegs(MachineOperand &Root) const {
+  // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
+  // Value is in Imm operand as i1 sign extended to int64_t.
+  // 1(-1) sets both neg_lo and neg_hi (NEG | NEG_HI), 0 leaves them clear.
+  assert((Root.isImm() && (Root.getImm() == -1 || Root.getImm() == 0)) &&
+         "expected i1 value");
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  if (Root.getImm() == -1)
+    Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+  return {{
+      [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+  }};
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operand.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegAbs(MachineOperand &Root) const {
+
+  assert(Root.isImm() && "Modifier for C must be an immediate");
+
+  unsigned Mods = SISrcMods::OP_SEL_1;
+  switch (Root.getImm()) {
+  default: // Any other value will be silently ignored (considered as 0).
+    break;
+  case 1:
+    Mods ^= SISrcMods::NEG;
+    break;
+  case 2:
+    Mods ^= SISrcMods::ABS;
+    break;
+  case 3:
+    Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+    break;
+  }
+
+  return {{
+      [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+  }};
+}
+
 InstructionSelector::ComplexRendererFns
 AMDGPUInstructionSelector::selectWMMAOpSelVOP3PMods(
     MachineOperand &Root) const {
@@ -5149,6 +5212,35 @@ AMDGPUInstructionSelector::selectSWMMACIndex16(MachineOperand &Root) const {
   }};
 }
 
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectSWMMACIndex32(MachineOperand &Root) const {
+  Register Src =
+      getDefIgnoringCopies(Root.getReg(), *MRI)->getOperand(0).getReg();
+  unsigned Key = 0;
+
+  Register S32 = matchZeroExtendFromS32(*MRI, Src);
+  if (!S32)
+    S32 = matchAnyExtendFromS32(Src);
+
+  if (S32) {
+    const MachineInstr *Def = getDefIgnoringCopies(S32, *MRI);
+    if (Def->getOpcode() == TargetOpcode::G_UNMERGE_VALUES) {
+      assert(Def->getNumOperands() == 3);
+      Register DstReg1 = Def->getOperand(1).getReg();
+      if (mi_match(S32, *MRI,
+                   m_any_of(m_SpecificReg(DstReg1), m_Copy(m_Reg(DstReg1))))) {
+        Src = Def->getOperand(2).getReg();
+        Key = 1;
+      }
+    }
+  }
+
+  return {{
+      [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
+      [=](MachineInstrBuilder &MIB) { MIB.addImm(Key); } // index_key
+  }};
+}
+
 InstructionSelector::ComplexRendererFns
 AMDGPUInstructionSelector::selectVOP3OpSelMods(MachineOperand &Root) const {
   Register Src;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
index 8e9e573147a86..2cb7904d27ccc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
@@ -201,6 +201,10 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
 
   InstructionSelector::ComplexRendererFns
   selectVOP3PModsNeg(MachineOperand &Root) const;
+  InstructionSelector::ComplexRendererFns
+  selectVOP3PModsNegs(MachineOperand &Root) const;
+  InstructionSelector::ComplexRendererFns
+  selectVOP3PModsNegAbs(MachineOperand &Root) const;
 
   InstructionSelector::ComplexRendererFns
   selectWMMAOpSelVOP3PMods(MachineOperand &Root) const;
@@ -217,6 +221,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
   selectSWMMACIndex8(MachineOperand &Root) const;
   InstructionSelector::ComplexRendererFns
   selectSWMMACIndex16(MachineOperand &Root) const;
+  InstructionSelector::ComplexRendererFns
+  selectSWMMACIndex32(MachineOperand &Root) const;
 
   InstructionSelector::ComplexRendererFns
   selectVOP3OpSelMods(MachineOperand &Root) const;
@@ -411,6 +417,9 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
   // shift amount operand's `ShAmtBits` bits is unneeded.
   bool isUnneededShiftMask(const MachineInstr &MI, unsigned ShAmtBits) const;
 
+  /// Match an any extend from a 32-bit value to 64-bit.
+  Register matchAnyExtendFromS32(Register Reg) const;
+
   const SIInstrInfo &TII;
   const SIRegisterInfo &TRI;
   const AMDGPURegisterBankInfo &RBI;
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 35de49c27b32a..35be8338dac6f 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -157,6 +157,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyNegHi,
     ImmTyIndexKey8bit,
     ImmTyIndexKey16bit,
+    ImmTyIndexKey32bit,
     ImmTyDPP8,
     ImmTyDppCtrl,
     ImmTyDppRowMask,
@@ -174,8 +175,16 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyWaitEXP,
     ImmTyWaitVAVDst,
     ImmTyWaitVMVSrc,
-    ImmTyByteSel,
     ImmTyBitOp3,
+    ImmTyMatrixAFMT,
+    ImmTyMatrixBFMT,
+    ImmTyMatrixAScale,
+    ImmTyMatrixBScale,
+    ImmTyMatrixAScaleFmt,
+    ImmTyMatrixBScaleFmt,
+    ImmTyMatrixAReuse,
+    ImmTyMatrixBReuse,
+    ImmTyByteSel,
   };
 
   // Immediate operand kind.
@@ -419,6 +428,15 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isCPol() const { return isImmTy(ImmTyCPol); }
   bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
   bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
+  bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+  bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+  bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+  bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+  bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+  bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+  bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
+  bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
+  bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
   bool isFORMAT() const { return isImmTy(ImmTyFORMAT) && isUInt<7>(getImm()); }
   bool isDppFI() const { return isImmTy(ImmTyDppFI); }
@@ -747,6 +765,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VReg_256RegClassID, MVT::f64);
   }
 
+  bool isVISrc_512_f64() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_512RegClassID, MVT::f64);
+  }
+
   bool isVISrc_128B16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::i16);
   }
@@ -1114,6 +1136,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyCPol: OS << "CPol"; break;
     case ImmTyIndexKey8bit: OS << "index_key"; break;
     case ImmTyIndexKey16bit: OS << "index_key"; break;
+    case ImmTyIndexKey32bit: OS << "index_key"; break;
     case ImmTyTFE: OS << "TFE"; break;
     case ImmTyD16: OS << "D16"; break;
     case ImmTyFORMAT: OS << "FORMAT"; break;
@@ -1160,8 +1183,16 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyWaitEXP: OS << "WaitEXP"; break;
     case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
     case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
-    case ImmTyByteSel: OS << "ByteSel" ; break;
     case ImmTyBitOp3: OS << "BitOp3"; break;
+    case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+    case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+    case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+    case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+    case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+    case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
+    case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
+    case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
+    case ImmTyByteSel: OS << "ByteSel" ; break;
     }
     // clang-format on
   }
@@ -1698,6 +1729,19 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                                AMDGPUOperand::ImmTy ImmTy);
   ParseStatus parseIndexKey8bit(OperandVector &Operands);
   ParseStatus parseIndexKey16bit(OperandVector &Operands);
+  ParseStatus parseIndexKey32bit(OperandVector &Operands);
+  ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+                                AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAFMT(OperandVector &Operands);
+  ParseStatus parseMatrixBFMT(OperandVector &Operands);
+  ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+                                  AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScale(OperandVector &Operands);
+  ParseStatus parseMatrixBScale(OperandVector &Operands);
+  ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+                                     AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+  ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -1833,6 +1877,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                               const unsigned CPol);
   bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
   std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+  bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
   unsigned getConstantBusLimit(unsigned Opcode) const;
   bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
   bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5366,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
   return true;
 }
 
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+                                   const OperandVector &Operands) {
+  unsigned Opc = Inst.getOpcode();
+  const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+  const MCInstrDesc &Desc = MII.get(Opc);
+
+  auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+    int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+    if (FmtIdx == -1)
+      return true;
+    unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+    int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+    unsigned RegSize =
+        TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+    if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+      return true;
+
+    static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"};
+
+    Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+          "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+    return false;
+  };
+
+  return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+         validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
 bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
                                           const SMLoc &IDLoc,
                                           const OperandVector &Operands) {
@@ -5499,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
   if (!validateTFE(Inst, Operands)) {
     return false;
   }
+  if (!validateWMMA(Inst, Operands)) {
+    return false;
+  }
 
   return true;
 }
@@ -7133,7 +7212,9 @@ ParseStatus AMDGPUAsmParser::tryParseIndexKey(OperandVector &Operands,
   if (!Res.isSuccess())
     return Res;
 
-  if (ImmTy == AMDGPUOperand::...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/148570


More information about the llvm-commits mailing list