[llvm] [AMDGPU] gfx1250 v_wmma_ld_scale instructions (PR #152010)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 4 11:06:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mc

Author: Stanislav Mekhanoshin (rampitec)

<details>
<summary>Changes</summary>



---

Patch is 51.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152010.diff


10 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+84) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+69) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+13) 
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+11) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+8) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+78-42) 
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+21-7) 
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+170) 
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+10) 
- (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+90) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a83caa0db8a69..d33765db9cc7d 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyBitOp3,
     ImmTyMatrixAFMT,
     ImmTyMatrixBFMT,
+    ImmTyMatrixAScale,
+    ImmTyMatrixBScale,
+    ImmTyMatrixAScaleFmt,
+    ImmTyMatrixBScaleFmt,
     ImmTyMatrixAReuse,
     ImmTyMatrixBReuse,
     ImmTyScaleSel,
@@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
   bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
   bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+  bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+  bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+  bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+  bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
   bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
   bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyBitOp3: OS << "BitOp3"; break;
     case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
     case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+    case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+    case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+    case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+    case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
     case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
     case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
     case ImmTyScaleSel: OS << "ScaleSel" ; break;
@@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                                 AMDGPUOperand::ImmTy Type);
   ParseStatus parseMatrixAFMT(OperandVector &Operands);
   ParseStatus parseMatrixBFMT(OperandVector &Operands);
+  ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+                                  AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScale(OperandVector &Operands);
+  ParseStatus parseMatrixBScale(OperandVector &Operands);
+  ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+                                     AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+  ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
                            AMDGPUOperand::ImmTyMatrixBFMT);
 }
 
+ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
+                                                 StringRef Name,
+                                                 AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(
+      Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
+  return tryParseMatrixScale(Operands, "matrix_a_scale",
+                             AMDGPUOperand::ImmTyMatrixAScale);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
+  return tryParseMatrixScale(Operands, "matrix_b_scale",
+                             AMDGPUOperand::ImmTyMatrixBScale);
+}
+
+ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
+                                                    StringRef Name,
+                                                    AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(
+      Operands, Name,
+      {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
+      Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
+  return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt",
+                                AMDGPUOperand::ImmTyMatrixAScaleFmt);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
+  return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt",
+                                AMDGPUOperand::ImmTyMatrixBScaleFmt);
+}
+
 // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
 // values to live in a joint format operand in the MCInst encoding.
 ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
                           AMDGPUOperand::ImmTyMatrixBFMT, 0);
   }
 
+  int MatrixAScaleIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
+  if (MatrixAScaleIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAScale, 0);
+  }
+
+  int MatrixBScaleIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
+  if (MatrixBScaleIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBScale, 0);
+  }
+
+  int MatrixAScaleFmtIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
+  if (MatrixAScaleFmtIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
+  }
+
+  int MatrixBScaleFmtIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
+  if (MatrixBScaleFmtIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
+  }
+
   if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
     addOptionalImmOperand(Inst, Operands, OptIdx,
                           AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index 42c4d8b8a9717..ee8683a549a80 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
   printMatrixFMT(MI, OpNo, STI, O, 'b');
 }
 
+void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
+                                         const MCSubtargetInfo &STI,
+                                         raw_ostream &O, char AorB) {
+  auto Imm = MI->getOperand(OpNo).getImm() & 1;
+  if (Imm == 0)
+    return;
+
+  O << " matrix_" << AorB << "_scale:";
+  switch (Imm) {
+  default:
+    O << Imm;
+    break;
+  case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
+    O << "MATRIX_SCALE_ROW0";
+    break;
+  case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
+    O << "MATRIX_SCALE_ROW1";
+    break;
+  }
+}
+
+void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
+  printMatrixScale(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
+  printMatrixScale(MI, OpNo, STI, O, 'b');
+}
+
+void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+                                            const MCSubtargetInfo &STI,
+                                            raw_ostream &O, char AorB) {
+  auto Imm = MI->getOperand(OpNo).getImm() & 3;
+  if (Imm == 0)
+    return;
+
+  O << " matrix_" << AorB << "_scale_fmt:";
+  switch (Imm) {
+  default:
+    O << Imm;
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
+    O << "MATRIX_SCALE_FMT_E8";
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
+    O << "MATRIX_SCALE_FMT_E5M3";
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
+    O << "MATRIX_SCALE_FMT_E4M3";
+    break;
+  }
+}
+
+void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+                                             const MCSubtargetInfo &STI,
+                                             raw_ostream &O) {
+  printMatrixScaleFmt(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+                                             const MCSubtargetInfo &STI,
+                                             raw_ostream &O) {
+  printMatrixScaleFmt(MI, OpNo, STI, O, 'b');
+}
+
 void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
                                         const MCSubtargetInfo &STI,
                                         raw_ostream &O) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
index f6739b14926e1..be32061c64537 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
                        const MCSubtargetInfo &STI, raw_ostream &O);
   void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
                        const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixScale(const MCInst *MI, unsigned OpNo,
+                        const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
+  void printMatrixAScale(const MCInst *MI, unsigned OpNo,
+                         const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixBScale(const MCInst *MI, unsigned OpNo,
+                         const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+                           const MCSubtargetInfo &STI, raw_ostream &O,
+                           char AorB);
+  void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+                            const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+                            const MCSubtargetInfo &STI, raw_ostream &O);
   void printInterpSlot(const MCInst *MI, unsigned OpNo,
                        const MCSubtargetInfo &STI, raw_ostream &O);
   void printInterpAttr(const MCInst *MI, unsigned OpNo,
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index c56414519a6fe..deadb7aed0f69 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
   MATRIX_FMT_BF6 = 3,
   MATRIX_FMT_FP4 = 4
 };
+
+enum MatrixScale : unsigned {
+  MATRIX_SCALE_ROW0 = 0,
+  MATRIX_SCALE_ROW1 = 1,
+};
+
+enum MatrixScaleFmt : unsigned {
+  MATRIX_SCALE_FMT_E8 = 0,
+  MATRIX_SCALE_FMT_E5M3 = 1,
+  MATRIX_SCALE_FMT_E4M3 = 2
+};
 } // namespace WMMA
 
 namespace VOP3PEncoding {
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 4698a5805ee0c..50914a5ef231f 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
 def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
 def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;
 
+def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
+def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;
+
+def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
+def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;
+
 def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
 def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
 
@@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
   field bit HasNeg = HasModifiers;
   field bit HasMatrixReuse = 0;
   field bit HasMatrixFMT = 0;
+  field bit HasMatrixScale = 0;
+  field bit HasMatrixReuse = 0;
 
   field bit HasSrc0Mods = HasModifiers;
   field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 95fcd4ac1c101..457c0eed4f047 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1407,9 +1407,9 @@ let WaveSizePredicate = isWave64 in {
 }
 
 class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
-                             bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
-                             bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0,
-                             bit _IsF4 = 0>
+                        bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
+                        bit _HasMatrixFMT = 0, bit _HasMatrixScale = 0,
+                        bit _Scale16 = 0, bit _HasMatrixReuse = 0, bit _IsF4 = 0>
     : VOP3P_Profile<VOPProfile<ArgTy>> {
   bit IsIU = _IsIU;
   bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B
@@ -1417,6 +1417,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
 
   int IndexType = _IndexType;
   let HasMatrixFMT = _HasMatrixFMT;
+  let HasMatrixScale = _HasMatrixScale;
+  bit Scale16 = _Scale16;
   let HasMatrixReuse = _HasMatrixReuse;
 
   bit HasIModOp = _Has_ImodOp;
@@ -1455,6 +1457,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                                                             IsC_F16:  "_f16",
                                                             IsC_BF16: "_bf16",
                                                             1: "_b32")));
+  ValueType ScaleTy = !if(Scale16, i64, i32);
 
   // For f16 and bf16 matrices A and B, each element can be modified by
   // fneg(neg_lo,neg_hi = 1). For f32 and f64, neg_lo[0:1] is allowed, but
@@ -1516,6 +1519,13 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                        !eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit));
   dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt),
                                    (ins));
+  dag MatrixScaleSrc = !if(HasMatrixScale,
+                           !if(Scale16, (ins VCSrc_b64:$scale_src0, VCSrc_b64:$scale_src1),
+                                        (ins VCSrc_b32:$scale_src0, VCSrc_b32:$scale_src1)),
+                           (ins));
+  dag MatrixScale = !if(HasMatrixScale, (ins MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale,
+                                             MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt),
+                                        (ins));
   dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins));
   dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins));
   dag Neg = !cond(!and(NegLoAny, NegHiAny)             : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi),
@@ -1529,7 +1539,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                                                  (ins VRegSrc_64:$src2),
                                                  (ins VRegSrc_32:$src2)),
                                             IndexKey)),
-                      MatrixFMT, MatrixReuse, Clamp, Neg);
+                      MatrixScaleSrc, MatrixFMT, MatrixScale, MatrixReuse, Clamp, Neg);
 
   // asm
 
@@ -1538,13 +1548,15 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                              !eq(IndexType, 16) : "$index_key_16bit",
                              !eq(IndexType, 32) : "$index_key_32bit");
   string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", "");
+  string MatrixScaleSrcAsm = !if(HasMatrixScale, ", $scale_src0, $scale_src1", "");
+  string MatrixScaleAsm = !if(HasMatrixScale, "$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt", "");
   string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", "");
   string ClampAsm = !if(HasClamp, "$clamp", "");
   string NegAsm = !cond(!and(NegLoAny, NegHiAny)             : "$neg_lo$neg_hi",
                         !and(NegLoAny, !not(NegHiAny))       : "$neg_lo",
                         !and(!not(NegLoAny), !not(NegHiAny)) : "");
 
-  let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm;
+  let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixScaleSrcAsm#MatrxFMTAsm#MatrixScaleAsm#MatrixReuseAsm#NegAsm#ClampAsm;
 
   // isel patterns
   bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp));
@@ -1606,20 +1618,27 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
   dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins));
   dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2))));
   dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1,  (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2));
+  dag MatrixScaleInPat = !if(HasMatrixScale, (ins timm:$matrix_a_scale, timm:$matrix_a_scale_fmt, ScaleTy:$scale_src0,
+                                                  timm:$matrix_b_scale, timm:$matrix_b_scale_fmt, ScaleTy:$scale_src1),
+                                             (ins));
 
   dag MatrixReuseInPat = !if(HasMatrixReuse, (ins timm:$matrix_a_reuse, timm:$matrix_b_reuse), (ins));
+  dag MatrixScaleOutSrcPat = !if(HasMatrixScale, (ins ScaleTy:$scale_src0, ScaleTy:$scale_src1), (ins));
+  dag MatrixScaleOutModPat = !if(HasMatrixScale, (ins i32:$matrix_a_scale, i32:$matrix_b_scale, i32:$matrix_a_scale_fmt, i32:$matrix_b_scale_fmt), (ins));
   dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins));
 
-  dag WmmaInPat  = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat);
-  dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+  dag WmmaInPat  = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+  dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixScaleOutSrcPat, MatrixFMTOutPat,
+                        MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
 
   dag SwmmacInPat  = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat);
   dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat);
 
   // wmma pattern where src2 is inline imm uses _threeaddr pseudo,
   // can't use _twoaddr since it would violate src2 tied to vdst constraint.
-  dag WmmaInlineInPat  = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat);
-  dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+  dag WmmaInlineInPat  = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+  dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixScaleOutSrcPat,
+                              MatrixFMTOutPat, MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
 }
 
 def WMMAInstInfoTable : GenericTable {
@@ -1728,39 +1747,51 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32,   i32, v2i32, v4f32], 1,
 // *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored
 //                       for matrix A, index is i16; Matrix B uses all lanes
 
-def F32_F32_WMMA_w32             : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_BF16X32_WMMA_w32         : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_F16X32_WMMA_w32          : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F16_F16X32_WMMA_w32        ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/152010


More information about the llvm-commits mailing list