[llvm] b966978 - [GlobalISel][NFC] Introduce a GVecReduce wrapper class and a minor refactor.

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 12 14:13:34 PDT 2023


Author: Amara Emerson
Date: 2023-08-12T13:55:08-07:00
New Revision: b9669789c363912478cc35fcfe305ca26157ea72

URL: https://github.com/llvm/llvm-project/commit/b9669789c363912478cc35fcfe305ca26157ea72
DIFF: https://github.com/llvm/llvm-project/commit/b9669789c363912478cc35fcfe305ca26157ea72.diff

LOG: [GlobalISel][NFC] Introduce a GVecReduce wrapper class and a minor refactor.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
    llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 20dd73cc7ddb5c..ef2fca2e8ef481 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -400,6 +400,82 @@ class GIntrinsic final : public GenericMachineInstr {
   }
 };
 
+/// Represents a (non-sequential) vector reduction operation.
+class GVecReduce : public GenericMachineInstr {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_VECREDUCE_FADD:
+    case TargetOpcode::G_VECREDUCE_FMUL:
+    case TargetOpcode::G_VECREDUCE_FMAX:
+    case TargetOpcode::G_VECREDUCE_FMIN:
+    case TargetOpcode::G_VECREDUCE_ADD:
+    case TargetOpcode::G_VECREDUCE_MUL:
+    case TargetOpcode::G_VECREDUCE_AND:
+    case TargetOpcode::G_VECREDUCE_OR:
+    case TargetOpcode::G_VECREDUCE_XOR:
+    case TargetOpcode::G_VECREDUCE_SMAX:
+    case TargetOpcode::G_VECREDUCE_SMIN:
+    case TargetOpcode::G_VECREDUCE_UMAX:
+    case TargetOpcode::G_VECREDUCE_UMIN:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  /// Get the opcode for the equivalent scalar operation for this reduction.
+  /// E.g. for G_VECREDUCE_FADD, this returns G_FADD.
+  unsigned getScalarOpcForReduction() const {
+    unsigned ScalarOpc;
+    switch (getOpcode()) {
+    case TargetOpcode::G_VECREDUCE_FADD:
+      ScalarOpc = TargetOpcode::G_FADD;
+      break;
+    case TargetOpcode::G_VECREDUCE_FMUL:
+      ScalarOpc = TargetOpcode::G_FMUL;
+      break;
+    case TargetOpcode::G_VECREDUCE_FMAX:
+      ScalarOpc = TargetOpcode::G_FMAXNUM;
+      break;
+    case TargetOpcode::G_VECREDUCE_FMIN:
+      ScalarOpc = TargetOpcode::G_FMINNUM;
+      break;
+    case TargetOpcode::G_VECREDUCE_ADD:
+      ScalarOpc = TargetOpcode::G_ADD;
+      break;
+    case TargetOpcode::G_VECREDUCE_MUL:
+      ScalarOpc = TargetOpcode::G_MUL;
+      break;
+    case TargetOpcode::G_VECREDUCE_AND:
+      ScalarOpc = TargetOpcode::G_AND;
+      break;
+    case TargetOpcode::G_VECREDUCE_OR:
+      ScalarOpc = TargetOpcode::G_OR;
+      break;
+    case TargetOpcode::G_VECREDUCE_XOR:
+      ScalarOpc = TargetOpcode::G_XOR;
+      break;
+    case TargetOpcode::G_VECREDUCE_SMAX:
+      ScalarOpc = TargetOpcode::G_SMAX;
+      break;
+    case TargetOpcode::G_VECREDUCE_SMIN:
+      ScalarOpc = TargetOpcode::G_SMIN;
+      break;
+    case TargetOpcode::G_VECREDUCE_UMAX:
+      ScalarOpc = TargetOpcode::G_UMAX;
+      break;
+    case TargetOpcode::G_VECREDUCE_UMIN:
+      ScalarOpc = TargetOpcode::G_UMIN;
+      break;
+    default:
+      llvm_unreachable("Unhandled reduction");
+    }
+    return ScalarOpc;
+  }
+};
+
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H

diff  --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index a4dc9d99d772b3..023f72b84fd7c8 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4428,73 +4428,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
   return Legalized;
 }
 
-static unsigned getScalarOpcForReduction(unsigned Opc) {
-  unsigned ScalarOpc;
-  switch (Opc) {
-  case TargetOpcode::G_VECREDUCE_FADD:
-    ScalarOpc = TargetOpcode::G_FADD;
-    break;
-  case TargetOpcode::G_VECREDUCE_FMUL:
-    ScalarOpc = TargetOpcode::G_FMUL;
-    break;
-  case TargetOpcode::G_VECREDUCE_FMAX:
-    ScalarOpc = TargetOpcode::G_FMAXNUM;
-    break;
-  case TargetOpcode::G_VECREDUCE_FMIN:
-    ScalarOpc = TargetOpcode::G_FMINNUM;
-    break;
-  case TargetOpcode::G_VECREDUCE_ADD:
-    ScalarOpc = TargetOpcode::G_ADD;
-    break;
-  case TargetOpcode::G_VECREDUCE_MUL:
-    ScalarOpc = TargetOpcode::G_MUL;
-    break;
-  case TargetOpcode::G_VECREDUCE_AND:
-    ScalarOpc = TargetOpcode::G_AND;
-    break;
-  case TargetOpcode::G_VECREDUCE_OR:
-    ScalarOpc = TargetOpcode::G_OR;
-    break;
-  case TargetOpcode::G_VECREDUCE_XOR:
-    ScalarOpc = TargetOpcode::G_XOR;
-    break;
-  case TargetOpcode::G_VECREDUCE_SMAX:
-    ScalarOpc = TargetOpcode::G_SMAX;
-    break;
-  case TargetOpcode::G_VECREDUCE_SMIN:
-    ScalarOpc = TargetOpcode::G_SMIN;
-    break;
-  case TargetOpcode::G_VECREDUCE_UMAX:
-    ScalarOpc = TargetOpcode::G_UMAX;
-    break;
-  case TargetOpcode::G_VECREDUCE_UMIN:
-    ScalarOpc = TargetOpcode::G_UMIN;
-    break;
-  default:
-    llvm_unreachable("Unhandled reduction");
-  }
-  return ScalarOpc;
-}
-
 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
-  unsigned Opc = MI.getOpcode();
-  assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
-         Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
-         "Sequential reductions not expected");
+  auto &RdxMI = cast<GVecReduce>(MI);
 
   if (TypeIdx != 1)
     return UnableToLegalize;
 
   // The semantics of the normal non-sequential reductions allow us to freely
   // re-associate the operation.
-  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+  auto [DstReg, DstTy, SrcReg, SrcTy] = RdxMI.getFirst2RegLLTs();
 
   if (NarrowTy.isVector() &&
       (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
     return UnableToLegalize;
 
-  unsigned ScalarOpc = getScalarOpcForReduction(Opc);
+  unsigned ScalarOpc = RdxMI.getScalarOpcForReduction();
   SmallVector<Register> SplitSrcs;
   // If NarrowTy is a scalar then we're being asked to scalarize.
   const unsigned NumParts =
@@ -4539,10 +4488,10 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
   SmallVector<Register> PartialReductions;
   for (unsigned Part = 0; Part < NumParts; ++Part) {
     PartialReductions.push_back(
-        MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
+        MIRBuilder.buildInstr(RdxMI.getOpcode(), {DstTy}, {SplitSrcs[Part]})
+            .getReg(0));
   }
 
-
   // If the types involved are powers of 2, we can generate intermediate vector
   // ops, before generating a final reduction operation.
   if (isPowerOf2_32(SrcTy.getNumElements()) &&


        


More information about the llvm-commits mailing list