[llvm] 6387d38 - [AArch64][SME] Add an instruction mapping for SME pseudos

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 16 03:53:41 PST 2023


Author: Kerry McLaughlin
Date: 2023-01-16T11:52:44Z
New Revision: 6387d3896629e225100d91b5827ea67882496eb4

URL: https://github.com/llvm/llvm-project/commit/6387d3896629e225100d91b5827ea67882496eb4
DIFF: https://github.com/llvm/llvm-project/commit/6387d3896629e225100d91b5827ea67882496eb4.diff

LOG: [AArch64][SME] Add an instruction mapping for SME pseudos

Adds an instruction mapping to SMEInstrFormats.td that pairs SME
pseudos with the real instructions they are lowered to.
A new flag (SMEMatrixType) is also added to AArch64Inst, which
indicates the ZA base register required when emitting many of
the SME instructions.
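
For illustration, the generated table can be queried like this; a minimal
sketch, with the pseudo opcode chosen only as an example (getSMEPseudoMap is
emitted from the InstrMapping added in SMEInstrFormats.td, see the diff
below):

  // The mapping pairs each SME pseudo with its real instruction. A lookup on
  // an opcode that has no entry returns -1, so callers can fall back to the
  // existing custom-inserter switch.
  int RealOpc = AArch64::getSMEPseudoMap(AArch64::FMOPA_MPPZZ_S_PSEUDO);
  // RealOpc is now AArch64::FMOPA_MPPZZ_S; the SMEMatrixType flag on the
  // pseudo (SMEMatrixTileS here) tells the inserter to use ZAS0 as the
  // base register.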

This reduces the number of pseudos handled by the switch statement
in EmitInstrWithCustomInserter.
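
Condensed, the new dispatch in EmitInstrWithCustomInserter is a single table
lookup plus a small switch on the matrix type; a sketch of its shape (the
full hunk is in the diff below):

  int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode());
  if (SMEOrigInstr != -1) {
    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
    // The 3-bit SMEMatrixType field (TSFlags bits 13-11) selects the ZA base
    // register; EmitZAInstr then rebuilds the real instruction, forming the
    // tile register from the base register plus the pseudo's immediate
    // operand and copying the remaining operands across.
    uint64_t SMEMatrixType =
        TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
    switch (SMEMatrixType) {
    case AArch64::SMEMatrixTileS:
      return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB);
    // ...and likewise for the B, H, D and Q tile types.
    }
  }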

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D136856

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.h
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
    llvm/lib/Target/AArch64/SMEInstrFormats.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 93aac68e0997..511c103b53da 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2676,35 +2676,16 @@ AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
 }
 
 MachineBasicBlock *
-AArch64TargetLowering::EmitMopa(unsigned Opc, unsigned BaseReg,
-                                MachineInstr &MI, MachineBasicBlock *BB) const {
+AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
+                                   MachineInstr &MI,
+                                   MachineBasicBlock *BB) const {
   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
   MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
 
   MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
   MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // pn
-  MIB.add(MI.getOperand(2)); // pm
-  MIB.add(MI.getOperand(3)); // zn
-  MIB.add(MI.getOperand(4)); // zm
-
-  MI.eraseFromParent(); // The pseudo is gone now.
-  return BB;
-}
-
-MachineBasicBlock *
-AArch64TargetLowering::EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
-                                              MachineInstr &MI,
-                                              MachineBasicBlock *BB) const {
-  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-  MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
-
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // Slice index register
-  MIB.add(MI.getOperand(2)); // Slice index offset
-  MIB.add(MI.getOperand(3)); // pg
-  MIB.add(MI.getOperand(4)); // zn
+  for (unsigned I = 1; I < MI.getNumOperands(); ++I)
+    MIB.add(MI.getOperand(I));
 
   MI.eraseFromParent(); // The pseudo is gone now.
   return BB;
@@ -2727,25 +2708,28 @@ AArch64TargetLowering::EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const {
   return BB;
 }
 
-MachineBasicBlock *
-AArch64TargetLowering::EmitAddVectorToTile(unsigned Opc, unsigned BaseReg,
-                                           MachineInstr &MI,
-                                           MachineBasicBlock *BB) const {
-  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-  MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
-
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // pn
-  MIB.add(MI.getOperand(2)); // pm
-  MIB.add(MI.getOperand(3)); // zn
-
-  MI.eraseFromParent(); // The pseudo is gone now.
-  return BB;
-}
-
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
+
+  int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode());
+  if (SMEOrigInstr != -1) {
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    uint64_t SMEMatrixType =
+        TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
+    switch (SMEMatrixType) {
+    case (AArch64::SMEMatrixTileB):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB);
+    case (AArch64::SMEMatrixTileH):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB);
+    case (AArch64::SMEMatrixTileS):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB);
+    case (AArch64::SMEMatrixTileD):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB);
+    case (AArch64::SMEMatrixTileQ):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB);
+    }
+  }
+
   switch (MI.getOpcode()) {
   default:
 #ifndef NDEBUG
@@ -2795,94 +2779,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
   case AArch64::LDR_ZA_PSEUDO:
     return EmitFill(MI, BB);
-  case AArch64::BFMOPA_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::BFMOPA_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::BFMOPS_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::BFMOPS_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPAL_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::FMOPAL_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPSL_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::FMOPSL_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::FMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::FMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::FMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::FMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::FMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::UMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::UMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::UMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::UMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SUMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SUMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SUMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SUMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::USMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::USMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::USMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::USMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::UMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::UMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::UMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::UMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SUMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SUMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SUMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SUMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::USMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::USMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::USMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::USMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_B:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_B, AArch64::ZAB0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_H:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_H, AArch64::ZAH0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_S:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_S, AArch64::ZAS0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_D:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_D, AArch64::ZAD0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_Q:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_Q, AArch64::ZAQ0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_B:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_B, AArch64::ZAB0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_H:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_H, AArch64::ZAH0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_S:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_S, AArch64::ZAS0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_D:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_D, AArch64::ZAD0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_Q:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_Q, AArch64::ZAQ0, MI,
-                                  BB);
   case AArch64::ZERO_M_PSEUDO:
     return EmitZero(MI, BB);
-  case AArch64::ADDHA_MPPZ_PSEUDO_S:
-    return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::ADDVA_MPPZ_PSEUDO_S:
-    return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::ADDHA_MPPZ_PSEUDO_D:
-    return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::ADDVA_MPPZ_PSEUDO_D:
-    return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_D, AArch64::ZAD0, MI, BB);
   }
 }
 

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9cf99b308121..febb1161b370 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -595,15 +595,9 @@ class AArch64TargetLowering : public TargetLowering {
                                   MachineInstr &MI,
                                   MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitMopa(unsigned Opc, unsigned BaseReg, MachineInstr &MI,
-                              MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
-                                            MachineInstr &MI,
-                                            MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg,
+                                 MachineInstr &MI, MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitAddVectorToTile(unsigned Opc, unsigned BaseReg,
-                                         MachineInstr &MI,
-                                         MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,

diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 0a24896433a0..91179aa8046e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -45,6 +45,17 @@ def FalseLanesNone  : FalseLanesEnum<0>;
 def FalseLanesZero  : FalseLanesEnum<1>;
 def FalseLanesUndef : FalseLanesEnum<2>;
 
+class SMEMatrixTypeEnum<bits<3> val> {
+  bits<3> Value = val;
+}
+def SMEMatrixNone  : SMEMatrixTypeEnum<0>;
+def SMEMatrixTileB : SMEMatrixTypeEnum<1>;
+def SMEMatrixTileH : SMEMatrixTypeEnum<2>;
+def SMEMatrixTileS : SMEMatrixTypeEnum<3>;
+def SMEMatrixTileD : SMEMatrixTypeEnum<4>;
+def SMEMatrixTileQ : SMEMatrixTypeEnum<5>;
+def SMEMatrixArray : SMEMatrixTypeEnum<6>;
+
 // AArch64 Instruction Format
 class AArch64Inst<Format f, string cstr> : Instruction {
   field bits<32> Inst; // Instruction encoding.
@@ -65,16 +76,18 @@ class AArch64Inst<Format f, string cstr> : Instruction {
   bit isPTestLike = 0;
   FalseLanesEnum FalseLanes = FalseLanesNone;
   DestructiveInstTypeEnum DestructiveInstType = NotDestructive;
+  SMEMatrixTypeEnum SMEMatrixType = SMEMatrixNone;
   ElementSizeEnum ElementSize = ElementSizeNone;
 
-  let TSFlags{10}  = isPTestLike;
-  let TSFlags{9}   = isWhile;
-  let TSFlags{8-7} = FalseLanes.Value;
-  let TSFlags{6-3} = DestructiveInstType.Value;
-  let TSFlags{2-0} = ElementSize.Value;
+  let TSFlags{13-11} = SMEMatrixType.Value;
+  let TSFlags{10}    = isPTestLike;
+  let TSFlags{9}     = isWhile;
+  let TSFlags{8-7}   = FalseLanes.Value;
+  let TSFlags{6-3}   = DestructiveInstType.Value;
+  let TSFlags{2-0}   = ElementSize.Value;
 
-  let Pattern     = [];
-  let Constraints = cstr;
+  let Pattern       = [];
+  let Constraints   = cstr;
 }
 
 class InstSubst<string Asm, dag Result, bit EmitPriority = 0>

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index ec60927f3aa7..caf9421eb001 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -539,10 +539,11 @@ static inline unsigned getPACOpcodeForKey(AArch64PACKey::ID K, bool Zero) {
 }
 
 // struct TSFlags {
-#define TSFLAG_ELEMENT_SIZE_TYPE(X)      (X)       // 3-bits
-#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits
-#define TSFLAG_FALSE_LANE_TYPE(X)       ((X) << 7) // 2-bits
-#define TSFLAG_INSTR_FLAGS(X)           ((X) << 9) // 2-bits
+#define TSFLAG_ELEMENT_SIZE_TYPE(X)      (X)        // 3-bits
+#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3)  // 4-bits
+#define TSFLAG_FALSE_LANE_TYPE(X)       ((X) << 7)  // 2-bits
+#define TSFLAG_INSTR_FLAGS(X)           ((X) << 9)  // 2-bits
+#define TSFLAG_SME_MATRIX_TYPE(X)       ((X) << 11) // 3-bits
 // }
 
 namespace AArch64 {
@@ -580,14 +581,28 @@ enum FalseLaneType {
 static const uint64_t InstrFlagIsWhile     = TSFLAG_INSTR_FLAGS(0x1);
 static const uint64_t InstrFlagIsPTestLike = TSFLAG_INSTR_FLAGS(0x2);
 
+enum SMEMatrixType {
+  SMEMatrixTypeMask = TSFLAG_SME_MATRIX_TYPE(0x7),
+  SMEMatrixNone     = TSFLAG_SME_MATRIX_TYPE(0x0),
+  SMEMatrixTileB    = TSFLAG_SME_MATRIX_TYPE(0x1),
+  SMEMatrixTileH    = TSFLAG_SME_MATRIX_TYPE(0x2),
+  SMEMatrixTileS    = TSFLAG_SME_MATRIX_TYPE(0x3),
+  SMEMatrixTileD    = TSFLAG_SME_MATRIX_TYPE(0x4),
+  SMEMatrixTileQ    = TSFLAG_SME_MATRIX_TYPE(0x5),
+  SMEMatrixArray    = TSFLAG_SME_MATRIX_TYPE(0x6),
+};
+
 #undef TSFLAG_ELEMENT_SIZE_TYPE
 #undef TSFLAG_DESTRUCTIVE_INST_TYPE
 #undef TSFLAG_FALSE_LANE_TYPE
 #undef TSFLAG_INSTR_FLAGS
+#undef TSFLAG_SME_MATRIX_TYPE
 
 int getSVEPseudoMap(uint16_t Opcode);
 int getSVERevInstr(uint16_t Opcode);
 int getSVENonRevInstr(uint16_t Opcode);
+
+int getSMEPseudoMap(uint16_t Opcode);
 }
 
 } // end namespace llvm

diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 468db595e1b2..df08e5ec758c 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -49,15 +49,15 @@ def RDSVLI_XI  : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>
 def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
 def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
 
-def ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha">;
-def ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva">;
+defm ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha", int_aarch64_sme_addha>;
+defm ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva", int_aarch64_sme_addva>;
 
 def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
 }
 
 let Predicates = [HasSMEI16I64] in {
-def ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha">;
-def ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva">;
+defm ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha", int_aarch64_sme_addha>;
+defm ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva", int_aarch64_sme_addva>;
 }
 
 let Predicates = [HasSME] in {

diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index e5416030bf22..3556d7fcfbe0 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -25,17 +25,35 @@ def tileslice128 : ComplexPattern<i32 , 2, "SelectSMETileSlice<0>", []>; // nop
 def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;
 
 //===----------------------------------------------------------------------===//
-// SME Outer Products
+// SME Pseudo Classes
 //===----------------------------------------------------------------------===//
 
-class sme_outer_product_pseudo<ZPRRegOp zpr_ty>
+def getSMEPseudoMap : InstrMapping {
+  let FilterClass = "SMEPseudo2Instr";
+  let RowFields = ["PseudoName"];
+  let ColFields = ["IsInstr"];
+  let KeyCol = ["0"];
+  let ValueCols = [["1"]];
+}
+
+class SMEPseudo2Instr<string name, bit instr> {
+  string PseudoName = name;
+  bit IsInstr = instr;
+}
+
+class sme_outer_product_pseudo<ZPRRegOp zpr_ty, SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs), (ins i32imm:$tile, PPR3bAny:$pn, PPR3bAny:$pm,
                           zpr_ty:$zn, zpr_ty:$zm), []>,
       Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SME Outer Products
+//===----------------------------------------------------------------------===//
+
 class sme_fp_outer_product_inst<bit S, bits<2> sz, bit op, MatrixTileOperand za_ty,
                                 ZPRRegOp zpr_ty, string mnemonic>
     : I<(outs za_ty:$ZAda),
@@ -62,13 +80,13 @@ class sme_fp_outer_product_inst<bit S, bits<2> sz, bit op, MatrixTileOperand za_
 }
 
 multiclass sme_outer_product_fp32<bit S, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_fp_outer_product_inst<S, 0b00, 0b0, TileOp32, ZPR32, mnemonic> {
+  def NAME : sme_fp_outer_product_inst<S, 0b00, 0b0, TileOp32, ZPR32, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<2> ZAda;
     let Inst{1-0} = ZAda;
     let Inst{2}   = 0b0;
   }
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR32>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR32, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
                 (nxv4f32 ZPR32:$zn), (nxv4f32 ZPR32:$zm)),
@@ -76,12 +94,12 @@ multiclass sme_outer_product_fp32<bit S, string mnemonic, SDPatternOperator op>
 }
 
 multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_fp_outer_product_inst<S, 0b10, 0b0, TileOp64, ZPR64, mnemonic> {
+  def NAME : sme_fp_outer_product_inst<S, 0b10, 0b0, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<3> ZAda;
     let Inst{2-0} = ZAda;
   }
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR64>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR64, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
                 (nxv2f64 ZPR64:$zn), (nxv2f64 ZPR64:$zm)),
@@ -126,13 +144,13 @@ class sme_int_outer_product_inst<bits<3> opc, bit sz, bit sme2,
 multiclass sme_int_outer_product_i32<bits<3> opc, string mnemonic,
                                      SDPatternOperator op> {
   def NAME : sme_int_outer_product_inst<opc, 0b0, 0b0,  TileOp32,
-                                        ZPR8, mnemonic> {
+                                        ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<2> ZAda;
     let Inst{1-0} = ZAda;
     let Inst{2}   = 0b0;
   }
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_3:$tile, (nxv16i1 PPR3bAny:$pn), (nxv16i1 PPR3bAny:$pm),
                 (nxv16i8 ZPR8:$zn), (nxv16i8 ZPR8:$zm)),
@@ -142,12 +160,12 @@ multiclass sme_int_outer_product_i32<bits<3> opc, string mnemonic,
 multiclass sme_int_outer_product_i64<bits<3> opc, string mnemonic,
                                      SDPatternOperator op> {
   def NAME : sme_int_outer_product_inst<opc, 0b1, 0b0, TileOp64,
-                                        ZPR16, mnemonic> {
+                                        ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<3> ZAda;
     let Inst{2-0} = ZAda;
   }
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_7:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8i16 ZPR16:$zn), (nxv8i16 ZPR16:$zm)),
@@ -182,9 +200,9 @@ class sme_outer_product_widening_inst<bits<3> opc, ZPRRegOp zpr_ty, string mnemo
 }
 
 multiclass sme_bf16_outer_product<bits<3> opc, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>;
+  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1>;
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8bf16 ZPR16:$zn), (nxv8bf16 ZPR16:$zm)),
@@ -192,9 +210,9 @@ multiclass sme_bf16_outer_product<bits<3> opc, string mnemonic, SDPatternOperato
 }
 
 multiclass sme_f16_outer_product<bits<3> opc, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>;
+  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1>;
 
-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
 
   def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8f16 ZPR16:$zn), (nxv8f16 ZPR16:$zm)),
@@ -226,51 +244,42 @@ class sme_add_vector_to_tile_inst<bit op, bit V, MatrixTileOperand tile_ty,
   let Constraints = "$ZAda = $_ZAda";
 }
 
-class sme_add_vector_to_tile_u32<bit V, string mnemonic>
-    : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic> {
-  bits<2> ZAda;
-  let Inst{2}   = 0b0;
-  let Inst{1-0} = ZAda;
-}
-
-class sme_add_vector_to_tile_u64<bit V, string mnemonic>
-    : sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic> {
-  bits<3> ZAda;
-  let Inst{2-0} = ZAda;
-}
-
-class sme_add_vector_to_tile_pseudo<ZPRRegOp zpr_ty>
+class sme_add_vector_to_tile_pseudo<ZPRRegOp zpr_ty, SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs),
              (ins i32imm:$tile, PPR3bAny:$Pn, PPR3bAny:$Pm, zpr_ty:$Zn), []>,
       Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }
 
-def ADDHA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32>;
-def ADDVA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32>;
+multiclass sme_add_vector_to_tile_u32<bit V, string mnemonic, SDPatternOperator op> {
+    def NAME : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+  bits<2> ZAda;
+  let Inst{2}   = 0b0;
+  let Inst{1-0} = ZAda;
+  }
+
+  def _PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
 
-def : Pat<(int_aarch64_sme_addha
-            timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
-            (nxv4i32 ZPR32:$zn)),
-          (ADDHA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>;
-def : Pat<(int_aarch64_sme_addva
-            timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
+  def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
             (nxv4i32 ZPR32:$zn)),
-          (ADDVA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>;
+          (!cast<Instruction>(NAME # _PSEUDO_S) timm32_0_3:$tile, $pn, $pm, $zn)>;
+}
+
+multiclass sme_add_vector_to_tile_u64<bit V, string mnemonic, SDPatternOperator op> {
+    def NAME : sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+  bits<3> ZAda;
+  let Inst{2-0} = ZAda;
+  }
 
-let Predicates = [HasSMEI16I64] in {
-def ADDHA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64>;
-def ADDVA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64>;
+  def _PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;
 
-def : Pat<(int_aarch64_sme_addha
-            timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
-            (nxv2i64 ZPR64:$zn)),
-          (ADDHA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>;
-def : Pat<(int_aarch64_sme_addva
-            timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
-            (nxv2i64 ZPR64:$zn)),
-          (ADDVA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>;
+  let Predicates = [HasSMEI16I64] in {
+  def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
+                (nxv2i64 ZPR64:$zn)),
+            (!cast<Instruction>(NAME # _PSEUDO_D) timm32_0_7:$tile, $pn, $pm, $zn)>;
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -711,24 +720,27 @@ multiclass sme_vector_to_tile_patterns<Instruction inst, ValueType zpr_vt,
   }
 }
 
-class sme_mova_insert_pseudo
+class sme_mova_insert_pseudo<SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs), (ins i32imm:$tile, MatrixIndexGPR32Op12_15:$idx,
                           i32imm:$imm, PPR3bAny:$pg, ZPRAny:$zn), []>,
       Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }
 
 multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
   def _B : sme_vector_to_tile_inst<0b0, 0b00, !if(is_col, TileVectorOpV8,
                                                           TileVectorOpH8),
-                                   is_col, sme_elm_idx0_15, ZPR8, mnemonic> {
+                                   is_col, sme_elm_idx0_15, ZPR8, mnemonic>,
+                                   SMEPseudo2Instr<NAME # _B, 1> {
     bits<4> imm;
     let Inst{3-0} = imm;
   }
   def _H : sme_vector_to_tile_inst<0b0, 0b01, !if(is_col, TileVectorOpV16,
                                                           TileVectorOpH16),
-                                   is_col, sme_elm_idx0_7, ZPR16, mnemonic> {
+                                   is_col, sme_elm_idx0_7, ZPR16, mnemonic>,
+                                   SMEPseudo2Instr<NAME # _H, 1> {
     bits<1> ZAd;
     bits<3> imm;
     let Inst{3}   = ZAd;
@@ -736,7 +748,8 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
   }
   def _S : sme_vector_to_tile_inst<0b0, 0b10, !if(is_col, TileVectorOpV32,
                                                           TileVectorOpH32),
-                                   is_col, sme_elm_idx0_3, ZPR32, mnemonic> {
+                                   is_col, sme_elm_idx0_3, ZPR32, mnemonic>,
+                                   SMEPseudo2Instr<NAME # _S, 1> {
     bits<2> ZAd;
     bits<2> imm;
     let Inst{3-2} = ZAd;
@@ -744,7 +757,8 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
   }
   def _D : sme_vector_to_tile_inst<0b0, 0b11, !if(is_col, TileVectorOpV64,
                                                           TileVectorOpH64),
-                                   is_col, sme_elm_idx0_1, ZPR64, mnemonic> {
+                                   is_col, sme_elm_idx0_1, ZPR64, mnemonic>,
+                                   SMEPseudo2Instr<NAME # _D, 1> {
     bits<3> ZAd;
     bits<1> imm;
     let Inst{3-1} = ZAd;
@@ -752,7 +766,8 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
   }
   def _Q : sme_vector_to_tile_inst<0b1, 0b11, !if(is_col, TileVectorOpV128,
                                                           TileVectorOpH128),
-                                   is_col, sme_elm_idx0_0, ZPR128, mnemonic> {
+                                   is_col, sme_elm_idx0_0, ZPR128, mnemonic>,
+                                   SMEPseudo2Instr<NAME # _Q, 1> {
     bits<4> ZAd;
     bits<1> imm;
     let Inst{3-0} = ZAd;
@@ -760,11 +775,11 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
 
   // Pseudo instructions for lowering intrinsics, using immediates instead of
   // tile registers.
-  def _PSEUDO_B : sme_mova_insert_pseudo;
-  def _PSEUDO_H : sme_mova_insert_pseudo;
-  def _PSEUDO_S : sme_mova_insert_pseudo;
-  def _PSEUDO_D : sme_mova_insert_pseudo;
-  def _PSEUDO_Q : sme_mova_insert_pseudo;
+  def _PSEUDO_B : sme_mova_insert_pseudo<SMEMatrixTileB>, SMEPseudo2Instr<NAME # _B, 0>;
+  def _PSEUDO_H : sme_mova_insert_pseudo<SMEMatrixTileH>, SMEPseudo2Instr<NAME # _H, 0>;
+  def _PSEUDO_S : sme_mova_insert_pseudo<SMEMatrixTileS>, SMEPseudo2Instr<NAME # _S, 0>;
+  def _PSEUDO_D : sme_mova_insert_pseudo<SMEMatrixTileD>, SMEPseudo2Instr<NAME # _D, 0>;
+  def _PSEUDO_Q : sme_mova_insert_pseudo<SMEMatrixTileQ>, SMEPseudo2Instr<NAME # _Q, 0>;
 
   defm : sme_vector_to_tile_aliases<!cast<Instruction>(NAME # _B),
                                     !if(is_col, TileVectorOpV8,


        

