[llvm] [AMDGPU] gfx1250 v_wmma_ld_scale instructions (PR #152010)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 4 11:06:03 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mc
Author: Stanislav Mekhanoshin (rampitec)
---
Patch is 51.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152010.diff
10 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+84)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+69)
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+13)
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+11)
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+8)
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+78-42)
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+21-7)
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+170)
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+10)
- (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+90)
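The new modifiers are carried as small immediates whose encodings are defined in the SIDefines.h hunk below. As a minimal standalone sketch of that name/value mapping (the enum values come from the patch; the `scaleFmtName` helper is purely illustrative and not part of it):

```cpp
#include <cstdint>
#include <string>

// Encodings added to namespace WMMA in SIDefines.h by this patch.
enum MatrixScale : unsigned { MATRIX_SCALE_ROW0 = 0, MATRIX_SCALE_ROW1 = 1 };
enum MatrixScaleFmt : unsigned {
  MATRIX_SCALE_FMT_E8 = 0,
  MATRIX_SCALE_FMT_E5M3 = 1,
  MATRIX_SCALE_FMT_E4M3 = 2
};

// Illustrative helper (not in the patch): mirrors how the InstPrinter renders
// the 2-bit matrix_{a,b}_scale_fmt immediate, falling back to the raw value
// for anything unnamed.
std::string scaleFmtName(uint64_t Imm) {
  switch (Imm & 3) {
  case MATRIX_SCALE_FMT_E8:
    return "MATRIX_SCALE_FMT_E8";
  case MATRIX_SCALE_FMT_E5M3:
    return "MATRIX_SCALE_FMT_E5M3";
  case MATRIX_SCALE_FMT_E4M3:
    return "MATRIX_SCALE_FMT_E4M3";
  default:
    return std::to_string(Imm & 3);
  }
}
```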
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a83caa0db8a69..d33765db9cc7d 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyBitOp3,
ImmTyMatrixAFMT,
ImmTyMatrixBFMT,
+ ImmTyMatrixAScale,
+ ImmTyMatrixBScale,
+ ImmTyMatrixAScaleFmt,
+ ImmTyMatrixBScaleFmt,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyScaleSel,
@@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+ bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+ bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+ bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+ bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyBitOp3: OS << "BitOp3"; break;
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+ case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+ case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+ case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+ case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyScaleSel: OS << "ScaleSel" ; break;
@@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
AMDGPUOperand::ImmTy Type);
ParseStatus parseMatrixAFMT(OperandVector &Operands);
ParseStatus parseMatrixBFMT(OperandVector &Operands);
+ ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScale(OperandVector &Operands);
+ ParseStatus parseMatrixBScale(OperandVector &Operands);
+ ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+ ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
AMDGPUOperand::ImmTyMatrixBFMT);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(
+ Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
+ return tryParseMatrixScale(Operands, "matrix_a_scale",
+ AMDGPUOperand::ImmTyMatrixAScale);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
+ return tryParseMatrixScale(Operands, "matrix_b_scale",
+ AMDGPUOperand::ImmTyMatrixBScale);
+}
+
+ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(
+ Operands, Name,
+ {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
+ return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt",
+ AMDGPUOperand::ImmTyMatrixAScaleFmt);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
+ return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt",
+ AMDGPUOperand::ImmTyMatrixBScaleFmt);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
AMDGPUOperand::ImmTyMatrixBFMT, 0);
}
+ int MatrixAScaleIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
+ if (MatrixAScaleIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAScale, 0);
+ }
+
+ int MatrixBScaleIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
+ if (MatrixBScaleIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBScale, 0);
+ }
+
+ int MatrixAScaleFmtIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
+ if (MatrixAScaleFmtIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
+ }
+
+ int MatrixBScaleFmtIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
+ if (MatrixBScaleFmtIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index 42c4d8b8a9717..ee8683a549a80 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
printMatrixFMT(MI, OpNo, STI, O, 'b');
}
+void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O, char AorB) {
+ auto Imm = MI->getOperand(OpNo).getImm() & 1;
+ if (Imm == 0)
+ return;
+
+ O << " matrix_" << AorB << "_scale:";
+ switch (Imm) {
+ default:
+ O << Imm;
+ break;
+ case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
+ O << "MATRIX_SCALE_ROW0";
+ break;
+ case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
+ O << "MATRIX_SCALE_ROW1";
+ break;
+ }
+}
+
+void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScale(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScale(MI, OpNo, STI, O, 'b');
+}
+
+void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O, char AorB) {
+ auto Imm = MI->getOperand(OpNo).getImm() & 3;
+ if (Imm == 0)
+ return;
+
+ O << " matrix_" << AorB << "_scale_fmt:";
+ switch (Imm) {
+ default:
+ O << Imm;
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
+ O << "MATRIX_SCALE_FMT_E8";
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
+ O << "MATRIX_SCALE_FMT_E5M3";
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
+ O << "MATRIX_SCALE_FMT_E4M3";
+ break;
+ }
+}
+
+void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScaleFmt(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScaleFmt(MI, OpNo, STI, O, 'b');
+}
+
void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI,
raw_ostream &O) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
index f6739b14926e1..be32061c64537 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
+ void printMatrixAScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixBScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O,
+ char AorB);
+ void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpSlot(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpAttr(const MCInst *MI, unsigned OpNo,
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index c56414519a6fe..deadb7aed0f69 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
MATRIX_FMT_BF6 = 3,
MATRIX_FMT_FP4 = 4
};
+
+enum MatrixScale : unsigned {
+ MATRIX_SCALE_ROW0 = 0,
+ MATRIX_SCALE_ROW1 = 1,
+};
+
+enum MatrixScaleFmt : unsigned {
+ MATRIX_SCALE_FMT_E8 = 0,
+ MATRIX_SCALE_FMT_E5M3 = 1,
+ MATRIX_SCALE_FMT_E4M3 = 2
+};
} // namespace WMMA
namespace VOP3PEncoding {
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 4698a5805ee0c..50914a5ef231f 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;
+def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
+def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;
+
+def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
+def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;
+
def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
@@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
field bit HasNeg = HasModifiers;
field bit HasMatrixReuse = 0;
field bit HasMatrixFMT = 0;
+ field bit HasMatrixScale = 0;
+ field bit HasMatrixReuse = 0;
field bit HasSrc0Mods = HasModifiers;
field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 95fcd4ac1c101..457c0eed4f047 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1407,9 +1407,9 @@ let WaveSizePredicate = isWave64 in {
}
class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
- bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
- bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0,
- bit _IsF4 = 0>
+ bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
+ bit _HasMatrixFMT = 0, bit _HasMatrixScale = 0,
+ bit _Scale16 = 0, bit _HasMatrixReuse = 0, bit _IsF4 = 0>
: VOP3P_Profile<VOPProfile<ArgTy>> {
bit IsIU = _IsIU;
bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B
@@ -1417,6 +1417,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
int IndexType = _IndexType;
let HasMatrixFMT = _HasMatrixFMT;
+ let HasMatrixScale = _HasMatrixScale;
+ bit Scale16 = _Scale16;
let HasMatrixReuse = _HasMatrixReuse;
bit HasIModOp = _Has_ImodOp;
@@ -1455,6 +1457,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
IsC_F16: "_f16",
IsC_BF16: "_bf16",
1: "_b32")));
+ ValueType ScaleTy = !if(Scale16, i64, i32);
// For f16 and bf16 matrices A and B, each element can be modified by
// fneg(neg_lo,neg_hi = 1). For f32 and f64, neg_lo[0:1] is allowed, but
@@ -1516,6 +1519,13 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
!eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit));
dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt),
(ins));
+ dag MatrixScaleSrc = !if(HasMatrixScale,
+ !if(Scale16, (ins VCSrc_b64:$scale_src0, VCSrc_b64:$scale_src1),
+ (ins VCSrc_b32:$scale_src0, VCSrc_b32:$scale_src1)),
+ (ins));
+ dag MatrixScale = !if(HasMatrixScale, (ins MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale,
+ MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt),
+ (ins));
dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins));
dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins));
dag Neg = !cond(!and(NegLoAny, NegHiAny) : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi),
@@ -1529,7 +1539,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
(ins VRegSrc_64:$src2),
(ins VRegSrc_32:$src2)),
IndexKey)),
- MatrixFMT, MatrixReuse, Clamp, Neg);
+ MatrixScaleSrc, MatrixFMT, MatrixScale, MatrixReuse, Clamp, Neg);
// asm
@@ -1538,13 +1548,15 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
!eq(IndexType, 16) : "$index_key_16bit",
!eq(IndexType, 32) : "$index_key_32bit");
string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", "");
+ string MatrixScaleSrcAsm = !if(HasMatrixScale, ", $scale_src0, $scale_src1", "");
+ string MatrixScaleAsm = !if(HasMatrixScale, "$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt", "");
string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", "");
string ClampAsm = !if(HasClamp, "$clamp", "");
string NegAsm = !cond(!and(NegLoAny, NegHiAny) : "$neg_lo$neg_hi",
!and(NegLoAny, !not(NegHiAny)) : "$neg_lo",
!and(!not(NegLoAny), !not(NegHiAny)) : "");
- let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm;
+ let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixScaleSrcAsm#MatrxFMTAsm#MatrixScaleAsm#MatrixReuseAsm#NegAsm#ClampAsm;
// isel patterns
bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp));
@@ -1606,20 +1618,27 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins));
dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2))));
dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1, (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2));
+ dag MatrixScaleInPat = !if(HasMatrixScale, (ins timm:$matrix_a_scale, timm:$matrix_a_scale_fmt, ScaleTy:$scale_src0,
+ timm:$matrix_b_scale, timm:$matrix_b_scale_fmt, ScaleTy:$scale_src1),
+ (ins));
dag MatrixReuseInPat = !if(HasMatrixReuse, (ins timm:$matrix_a_reuse, timm:$matrix_b_reuse), (ins));
+ dag MatrixScaleOutSrcPat = !if(HasMatrixScale, (ins ScaleTy:$scale_src0, ScaleTy:$scale_src1), (ins));
+ dag MatrixScaleOutModPat = !if(HasMatrixScale, (ins i32:$matrix_a_scale, i32:$matrix_b_scale, i32:$matrix_a_scale_fmt, i32:$matrix_b_scale_fmt), (ins));
dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins));
- dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat);
- dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+ dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+ dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixScaleOutSrcPat, MatrixFMTOutPat,
+ MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
dag SwmmacInPat = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat);
dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat);
// wmma pattern where src2 is inline imm uses _threeaddr pseudo,
// can't use _twoaddr since it would violate src2 tied to vdst constraint.
- dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat);
- dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+ dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+ dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixScaleOutSrcPat,
+ MatrixFMTOutPat, MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
}
def WMMAInstInfoTable : GenericTable {
@@ -1728,39 +1747,51 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32, i32, v2i32, v4f32], 1,
// *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored
// for matrix A, index is i16; Matrix B uses all lanes
-def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F16_F16X32_WMMA_w32 ...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/152010
More information about the llvm-commits mailing list