[llvm] AMDGPU: Implement MC layer support for gfx1250 wmma instructions. (PR #148570)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 13 23:38:12 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mc
@llvm/pr-subscribers-backend-amdgpu
Author: Changpeng Fang (changpeng)
<details>
<summary>Changes</summary>
---
Patch is 331.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148570.diff
25 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/AMDGPU.td (+10)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+12)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+76)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h (+3)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+92)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h (+9)
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+198-3)
- (modified) llvm/lib/Target/AMDGPU/CMakeLists.txt (+2-1)
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+53-1)
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h (+1)
- (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.h (+3)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+121)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+21)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp (+3-2)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+9-3)
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+21)
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+21)
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+2)
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp (+23)
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+6)
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+441-84)
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+36-11)
- (added) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+1739)
- (added) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+490)
- (added) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+1001)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 91ace4d2b7f16..3507d0fdefd5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -838,6 +838,12 @@ def FeatureCvtFP8VOP1Bug : SubtargetFeature<"cvt-fp8-vop1-bug",
[FeatureFP8ConversionInsts]
>;
+def FeatureWMMA128bInsts : SubtargetFeature<"wmma-128b-insts",
+ "HasWMMA128bInsts",
+ "true",
+ "Has WMMA instructions where A and B matrices do not have duplicated data"
+>;
+
def FeaturePkFmacF16Inst : SubtargetFeature<"pk-fmac-f16-inst",
"HasPkFmacF16Inst",
"true",
@@ -1919,6 +1925,7 @@ def FeatureISAVersion12 : FeatureSet<
FeatureImageInsts,
FeatureExtendedImageInsts,
FeatureFP8ConversionInsts,
+ FeatureWMMA128bInsts,
FeatureIEEEMinimumMaximumInsts,
FeaturePackedTID,
FeatureVcmpxPermlaneHazard,
@@ -2602,6 +2609,9 @@ def HasFP8Insts : Predicate<"Subtarget->hasFP8Insts()">,
def HasFP8ConversionInsts : Predicate<"Subtarget->hasFP8ConversionInsts()">,
AssemblerPredicate<(all_of FeatureFP8ConversionInsts)>;
+def HasWMMA128bInsts : Predicate<"Subtarget->hasWMMA128bInsts()">,
+ AssemblerPredicate<(all_of FeatureWMMA128bInsts)>;
+
def HasFP8E5M3Insts : Predicate<"Subtarget->hasFP8E5M3Insts()">,
AssemblerPredicate<(all_of FeatureFP8E5M3Insts)>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..7b5d4077e85f3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -55,6 +55,14 @@ def gi_vop3pmodsneg :
GIComplexOperandMatcher<s32, "selectVOP3PModsNeg">,
GIComplexPatternEquiv<VOP3PModsNeg>;
+def gi_vop3pmodsnegs :
+ GIComplexOperandMatcher<s32, "selectVOP3PModsNegs">,
+ GIComplexPatternEquiv<VOP3PModsNegs>;
+
+def gi_dotiuvop3pmodsnegabs :
+ GIComplexOperandMatcher<s32, "selectVOP3PModsNegAbs">,
+ GIComplexPatternEquiv<VOP3PModsNegAbs>;
+
def gi_wmmaopselvop3pmods :
GIComplexOperandMatcher<s32, "selectWMMAOpSelVOP3PMods">,
GIComplexPatternEquiv<WMMAOpSelVOP3PMods>;
@@ -83,6 +91,10 @@ def gi_swmmacindex16 :
GIComplexOperandMatcher<s32, "selectSWMMACIndex16">,
GIComplexPatternEquiv<SWMMACIndex16>;
+def gi_swmmacindex32 :
+ GIComplexOperandMatcher<s64, "selectSWMMACIndex32">,
+ GIComplexPatternEquiv<SWMMACIndex32>;
+
def gi_vop3opselmods :
GIComplexOperandMatcher<s32, "selectVOP3OpSelMods">,
GIComplexPatternEquiv<VOP3OpSelMods>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 202693b316122..7a8391b52ab0b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3260,6 +3260,47 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
return true;
}
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo
+// applies to the matrix's even k elements and neg_hi to its odd k elements.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegs(SDValue In, SDValue &Src) const {
+ const ConstantSDNode *C = cast<ConstantSDNode>(In);
+ // A literal i1 value set in the intrinsic represents SrcMods for the next
+ // operand: 1 promotes packed values to signed, 0 treats them as unsigned.
+ assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value");
+
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned SrcSign = C->getZExtValue();
+ if (SrcSign == 1)
+ Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+
+ Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operand.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const {
+ const ConstantSDNode *C = cast<ConstantSDNode>(In);
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned SrcMod = C->getZExtValue();
+ switch (SrcMod) {
+ default: // Any other value is silently ignored (treated as 0).
+ break;
+ case 1:
+ Mods ^= SISrcMods::NEG;
+ break;
+ case 2:
+ Mods ^= SISrcMods::ABS;
+ break;
+ case 3:
+ Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+ break;
+ }
+
+ Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
SDValue &Src) const {
const ConstantSDNode *C = cast<ConstantSDNode>(In);
@@ -3611,6 +3652,41 @@ bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src,
return true;
}
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex32(SDValue In, SDValue &Src,
+ SDValue &IndexKey) const {
+ unsigned Key = 0;
+ Src = In;
+
+ SDValue InI32;
+
+ if (In.getOpcode() == ISD::ANY_EXTEND || In.getOpcode() == ISD::ZERO_EXTEND) {
+ const SDValue &ExtendSrc = In.getOperand(0);
+ if (ExtendSrc.getValueSizeInBits() == 32)
+ InI32 = ExtendSrc;
+ } else if (In->getOpcode() == ISD::BITCAST) {
+ const SDValue &CastSrc = In.getOperand(0);
+ if (CastSrc.getOpcode() == ISD::BUILD_VECTOR &&
+ CastSrc.getOperand(0).getValueSizeInBits() == 32) {
+ ConstantSDNode *Zero = dyn_cast<ConstantSDNode>(CastSrc.getOperand(1));
+ if (Zero && Zero->getZExtValue() == 0)
+ InI32 = CastSrc.getOperand(0);
+ }
+ }
+
+ if (InI32 && InI32.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+ const SDValue &ExtractVecEltSrc = InI32.getOperand(0);
+ ConstantSDNode *EltIdx = dyn_cast<ConstantSDNode>(InI32.getOperand(1));
+ if (ExtractVecEltSrc.getValueSizeInBits() == 64 && EltIdx &&
+ EltIdx->getZExtValue() == 1) {
+ Key = 1;
+ Src = ExtractVecEltSrc;
+ }
+ }
+
+ IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+ return true;
+}
+
bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src,
SDValue &SrcMods) const {
Src = In;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
index f3b9364fdb92b..9967f46e085e4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -222,6 +222,8 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
bool SelectVOP3PModsDOT(SDValue In, SDValue &Src, SDValue &SrcMods) const;
bool SelectVOP3PModsNeg(SDValue In, SDValue &Src) const;
+ bool SelectVOP3PModsNegs(SDValue In, SDValue &Src) const;
+ bool SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const;
bool SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const;
bool SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
@@ -233,6 +235,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
bool SelectSWMMACIndex8(SDValue In, SDValue &Src, SDValue &IndexKey) const;
bool SelectSWMMACIndex16(SDValue In, SDValue &Src, SDValue &IndexKey) const;
+ bool SelectSWMMACIndex32(SDValue In, SDValue &Src, SDValue &IndexKey) const;
bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index ea79c57080faa..b3952305d24a2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -3513,6 +3513,25 @@ static Register matchZeroExtendFromS32(MachineRegisterInfo &MRI, Register Reg) {
return Register();
}
+Register AMDGPUInstructionSelector::matchAnyExtendFromS32(Register Reg) const {
+ Register AnyExtSrc;
+ if (mi_match(Reg, *MRI, m_GAnyExt(m_Reg(AnyExtSrc))))
+ return MRI->getType(AnyExtSrc) == LLT::scalar(32) ? AnyExtSrc : Register();
+
+ // Match legalized form %zext = G_MERGE_VALUES (s32 %x), (s32 G_IMPLICIT_DEF)
+ const MachineInstr *Def = getDefIgnoringCopies(Reg, *MRI);
+ if (Def->getOpcode() != AMDGPU::G_MERGE_VALUES)
+ return Register();
+
+ assert(Def->getNumOperands() == 3 &&
+ MRI->getType(Def->getOperand(0).getReg()) == LLT::scalar(64));
+
+ if (mi_match(Def->getOperand(2).getReg(), *MRI, m_GImplicitDef()))
+ return Def->getOperand(1).getReg();
+
+ return Register();
+}
+
bool AMDGPUInstructionSelector::selectGlobalLoadLds(MachineInstr &MI) const{
if (!Subtarget->hasVMemToLDSLoad())
return false;
@@ -4919,6 +4938,50 @@ AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
}};
}
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo
+// applies to the matrix's even k elements and neg_hi to its odd k elements.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegs(MachineOperand &Root) const {
+ // A literal i1 value set in the intrinsic represents SrcMods for the next
+ // operand; it arrives in the Imm operand as an i1 sign-extended to int64_t.
+ // 1 (i.e. -1) promotes packed values to signed, 0 treats them as unsigned.
+ assert((Root.isImm() && (Root.getImm() == -1 || Root.getImm() == 0)) &&
+ "expected i1 value");
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ if (Root.getImm() == -1)
+ Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+ return {{
+ [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+ }};
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operand.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegAbs(MachineOperand &Root) const {
+
+ assert(Root.isImm() && "Modifier for C must be an immediate");
+
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ switch (Root.getImm()) {
+ default: // Any other value is silently ignored (treated as 0).
+ break;
+ case 1:
+ Mods ^= SISrcMods::NEG;
+ break;
+ case 2:
+ Mods ^= SISrcMods::ABS;
+ break;
+ case 3:
+ Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+ break;
+ }
+
+ return {{
+ [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+ }};
+}
+
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectWMMAOpSelVOP3PMods(
MachineOperand &Root) const {
@@ -5149,6 +5212,35 @@ AMDGPUInstructionSelector::selectSWMMACIndex16(MachineOperand &Root) const {
}};
}
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectSWMMACIndex32(MachineOperand &Root) const {
+ Register Src =
+ getDefIgnoringCopies(Root.getReg(), *MRI)->getOperand(0).getReg();
+ unsigned Key = 0;
+
+ Register S32 = matchZeroExtendFromS32(*MRI, Src);
+ if (!S32)
+ S32 = matchAnyExtendFromS32(Src);
+
+ if (S32) {
+ const MachineInstr *Def = getDefIgnoringCopies(S32, *MRI);
+ if (Def->getOpcode() == TargetOpcode::G_UNMERGE_VALUES) {
+ assert(Def->getNumOperands() == 3);
+ Register DstReg1 = Def->getOperand(1).getReg();
+ if (mi_match(S32, *MRI,
+ m_any_of(m_SpecificReg(DstReg1), m_Copy(m_Reg(DstReg1))))) {
+ Src = Def->getOperand(2).getReg();
+ Key = 1;
+ }
+ }
+ }
+
+ return {{
+ [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
+ [=](MachineInstrBuilder &MIB) { MIB.addImm(Key); } // index_key
+ }};
+}
+
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3OpSelMods(MachineOperand &Root) const {
Register Src;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
index 8e9e573147a86..2cb7904d27ccc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
@@ -201,6 +201,10 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
InstructionSelector::ComplexRendererFns
selectVOP3PModsNeg(MachineOperand &Root) const;
+ InstructionSelector::ComplexRendererFns
+ selectVOP3PModsNegs(MachineOperand &Root) const;
+ InstructionSelector::ComplexRendererFns
+ selectVOP3PModsNegAbs(MachineOperand &Root) const;
InstructionSelector::ComplexRendererFns
selectWMMAOpSelVOP3PMods(MachineOperand &Root) const;
@@ -217,6 +221,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
selectSWMMACIndex8(MachineOperand &Root) const;
InstructionSelector::ComplexRendererFns
selectSWMMACIndex16(MachineOperand &Root) const;
+ InstructionSelector::ComplexRendererFns
+ selectSWMMACIndex32(MachineOperand &Root) const;
InstructionSelector::ComplexRendererFns
selectVOP3OpSelMods(MachineOperand &Root) const;
@@ -411,6 +417,9 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
// shift amount operand's `ShAmtBits` bits is unneeded.
bool isUnneededShiftMask(const MachineInstr &MI, unsigned ShAmtBits) const;
+ /// Match an any-extend from a 32-bit value to a 64-bit one.
+ Register matchAnyExtendFromS32(Register Reg) const;
+
const SIInstrInfo &TII;
const SIRegisterInfo &TRI;
const AMDGPURegisterBankInfo &RBI;
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 35de49c27b32a..35be8338dac6f 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -157,6 +157,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyNegHi,
ImmTyIndexKey8bit,
ImmTyIndexKey16bit,
+ ImmTyIndexKey32bit,
ImmTyDPP8,
ImmTyDppCtrl,
ImmTyDppRowMask,
@@ -174,8 +175,16 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyWaitEXP,
ImmTyWaitVAVDst,
ImmTyWaitVMVSrc,
- ImmTyByteSel,
ImmTyBitOp3,
+ ImmTyMatrixAFMT,
+ ImmTyMatrixBFMT,
+ ImmTyMatrixAScale,
+ ImmTyMatrixBScale,
+ ImmTyMatrixAScaleFmt,
+ ImmTyMatrixBScaleFmt,
+ ImmTyMatrixAReuse,
+ ImmTyMatrixBReuse,
+ ImmTyByteSel,
};
// Immediate operand kind.
@@ -419,6 +428,15 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isCPol() const { return isImmTy(ImmTyCPol); }
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
+ bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+ bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+ bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+ bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+ bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+ bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+ bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
+ bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
+ bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
bool isFORMAT() const { return isImmTy(ImmTyFORMAT) && isUInt<7>(getImm()); }
bool isDppFI() const { return isImmTy(ImmTyDppFI); }
@@ -747,6 +765,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isRegOrInlineNoMods(AMDGPU::VReg_256RegClassID, MVT::f64);
}
+ bool isVISrc_512_f64() const {
+ return isRegOrInlineNoMods(AMDGPU::VReg_512RegClassID, MVT::f64);
+ }
+
bool isVISrc_128B16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::i16);
}
@@ -1114,6 +1136,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyCPol: OS << "CPol"; break;
case ImmTyIndexKey8bit: OS << "index_key"; break;
case ImmTyIndexKey16bit: OS << "index_key"; break;
+ case ImmTyIndexKey32bit: OS << "index_key"; break;
case ImmTyTFE: OS << "TFE"; break;
case ImmTyD16: OS << "D16"; break;
case ImmTyFORMAT: OS << "FORMAT"; break;
@@ -1160,8 +1183,16 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyWaitEXP: OS << "WaitEXP"; break;
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
- case ImmTyByteSel: OS << "ByteSel" ; break;
case ImmTyBitOp3: OS << "BitOp3"; break;
+ case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+ case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+ case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+ case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+ case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+ case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
+ case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
+ case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
+ case ImmTyByteSel: OS << "ByteSel" ; break;
}
// clang-format on
}
@@ -1698,6 +1729,19 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
AMDGPUOperand::ImmTy ImmTy);
ParseStatus parseIndexKey8bit(OperandVector &Operands);
ParseStatus parseIndexKey16bit(OperandVector &Operands);
+ ParseStatus parseIndexKey32bit(OperandVector &Operands);
+ ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAFMT(OperandVector &Operands);
+ ParseStatus parseMatrixBFMT(OperandVector &Operands);
+ ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScale(OperandVector &Operands);
+ ParseStatus parseMatrixBScale(OperandVector &Operands);
+ ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+ ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -1833,6 +1877,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
const unsigned CPol);
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+ bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
unsigned getConstantBusLimit(unsigned Opcode) const;
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5366,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
return true;
}
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+ const OperandVector &Operands) {
+ unsigned Opc = Inst.getOpcode();
+ const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+ const MCInstrDesc &Desc = MII.get(Opc);
+
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+ int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+ if (FmtIdx == -1)
+ return true;
+ unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+ int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+ unsigned RegSize =
+ TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+ return true;
+
+ static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"};
+
+ Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+ "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+ return false;
+ };
+
+ return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+ validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
const SMLoc &IDLoc,
const OperandVector &Operands) {
@@ -5499,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
if (!validateTFE(Inst, Operands)) {
return false;
}
+ if (!validateWMMA(Inst, Operands)) {
+ return false;
+ }
return true;
}
@@ -7133,7 +7212,9 @@ ParseStatus AMDGPUAsmParser::tryParseIndexKey(OperandVector &Operands,
if (!Res.isSuccess())
return Res;
- if (ImmTy == AMDGPUOperand::...
[truncated]
``````````
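As a rough illustration of what the new `SelectVOP3PModsNegs`/`SelectVOP3PModsNegAbs` selectors in the patch compute, here is a minimal standalone sketch of the immediate-to-`src_mods` mapping. The bit values below are illustrative stand-ins, not the real `SISrcMods` encodings from `SIDefines.h`:

```cpp
#include <cassert>
#include <cstdio>

// Stand-in modifier bits; the real encodings live in SISrcMods (SIDefines.h).
constexpr unsigned NEG = 1u << 0;      // negate low half (even k elements)
constexpr unsigned ABS = 1u << 1;      // absolute value
constexpr unsigned NEG_HI = 1u << 2;   // negate high half (odd k elements)
constexpr unsigned OP_SEL_1 = 1u << 3; // default bit both selectors start from

// i1 immediate (SelectVOP3PModsNegs): 1 sets both neg_lo and neg_hi.
unsigned modsFromNegs(unsigned SrcSign) {
  assert(SrcSign <= 1 && "expected i1 value");
  unsigned Mods = OP_SEL_1;
  if (SrcSign == 1)
    Mods ^= (NEG | NEG_HI);
  return Mods;
}

// i16 immediate (SelectVOP3PModsNegAbs): 1 -> neg, 2 -> abs, 3 -> neg|abs,
// anything else is silently treated as 0 (no modifier).
unsigned modsFromNegAbs(unsigned SrcMod) {
  unsigned Mods = OP_SEL_1;
  switch (SrcMod) {
  default: break;
  case 1: Mods ^= NEG; break;
  case 2: Mods ^= ABS; break;
  case 3: Mods ^= (NEG | ABS); break;
  }
  return Mods;
}

int main() {
  std::printf("negs(1)   = 0x%x\n", modsFromNegs(1));   // OP_SEL_1|NEG|NEG_HI
  std::printf("negabs(3) = 0x%x\n", modsFromNegAbs(3)); // OP_SEL_1|NEG|ABS
}
```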
</details>
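The new `validateWMMA` check ties `matrix_a_fmt`/`matrix_b_fmt` to the width of the corresponding source register tuple. A toy version of that arithmetic might look like the sketch below; the per-format register counts are an assumption based on the format names (8-/6-/4-bit elements), and the authoritative mapping is `wmmaScaleF8F6F4FormatToNumRegs` in AMDGPUBaseInfo:

```cpp
#include <cstdio>

enum MatrixFmt { FMT_FP8, FMT_BF8, FMT_FP6, FMT_BF6, FMT_FP4 };

// Hypothetical stand-in for AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(); the
// counts assume 8-/6-/4-bit matrix elements packed into 32-bit VGPRs and
// are not taken from the gfx1250 ISA.
unsigned numRegsForFmt(MatrixFmt Fmt) {
  switch (Fmt) {
  case FMT_FP8:
  case FMT_BF8: return 16;
  case FMT_FP6:
  case FMT_BF6: return 12;
  case FMT_FP4: return 8;
  }
  return 0;
}

// Mirrors the check in validateWMMA: the source operand's register class
// size must be exactly 32 bits per required VGPR, otherwise the assembler
// reports a wrong register tuple size.
bool tupleSizeMatches(MatrixFmt Fmt, unsigned RegSizeInBits) {
  return RegSizeInBits == numRegsForFmt(Fmt) * 32;
}

int main() {
  std::printf("%d\n", tupleSizeMatches(FMT_FP8, 16 * 32)); // 1: sizes agree
  std::printf("%d\n", tupleSizeMatches(FMT_FP4, 16 * 32)); // 0: asm error
}
```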
https://github.com/llvm/llvm-project/pull/148570
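One more sketch, for the `index_key` folding done by `SelectSWMMACIndex32`/`selectSWMMACIndex32`: the 32-bit sparse index lives in a 64-bit register pair, and when the selector can prove the operand is the high half (an extract of element 1, or the second result of an unmerge), it keeps the 64-bit source and sets `index_key = 1` instead of materializing a shift. A conceptual model, with names that are illustrative rather than LLVM API:

```cpp
#include <cstdint>
#include <cstdio>

// The sparse index is a 32-bit value carried in a 64-bit register pair;
// index_key selects which half the hardware reads.
struct IndexOperand {
  uint64_t SrcPair;  // full 64-bit source register pair
  unsigned IndexKey; // 0 = low half (bits 31:0), 1 = high half (bits 63:32)
};

// Fold a proven "high half" extraction into index_key rather than code.
IndexOperand foldIndexKey(uint64_t Pair, bool IndexIsHighHalf) {
  return {Pair, IndexIsHighHalf ? 1u : 0u};
}

uint32_t readIndex(const IndexOperand &Op) {
  return static_cast<uint32_t>(Op.SrcPair >> (32 * Op.IndexKey));
}

int main() {
  IndexOperand Op = foldIndexKey(0xDEADBEEF00112233ull, /*IndexIsHighHalf=*/true);
  std::printf("index_key=%u index=0x%x\n", Op.IndexKey, readIndex(Op));
}
```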
More information about the llvm-commits mailing list