[llvm] b966978 - [GlobalISel][NFC] Introduce a GVecReduce wrapper class and a minor refactor.
Amara Emerson via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 12 14:13:34 PDT 2023
Author: Amara Emerson
Date: 2023-08-12T13:55:08-07:00
New Revision: b9669789c363912478cc35fcfe305ca26157ea72
URL: https://github.com/llvm/llvm-project/commit/b9669789c363912478cc35fcfe305ca26157ea72
DIFF: https://github.com/llvm/llvm-project/commit/b9669789c363912478cc35fcfe305ca26157ea72.diff
LOG: [GlobalISel][NFC] Introduce a GVecReduce wrapper class and a minor refactor.
Added:
Modified:
llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 20dd73cc7ddb5c..ef2fca2e8ef481 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -400,6 +400,82 @@ class GIntrinsic final : public GenericMachineInstr {
}
};
+/// Represents a non-sequential vector reduction: any G_VECREDUCE_* opcode listed in classof() (the G_VECREDUCE_SEQ_* forms are not matched).
+class GVecReduce : public GenericMachineInstr {
+public:
+ static bool classof(const MachineInstr *MI) {
+ switch (MI->getOpcode()) {
+ case TargetOpcode::G_VECREDUCE_FADD:
+ case TargetOpcode::G_VECREDUCE_FMUL:
+ case TargetOpcode::G_VECREDUCE_FMAX:
+ case TargetOpcode::G_VECREDUCE_FMIN:
+ case TargetOpcode::G_VECREDUCE_ADD:
+ case TargetOpcode::G_VECREDUCE_MUL:
+ case TargetOpcode::G_VECREDUCE_AND:
+ case TargetOpcode::G_VECREDUCE_OR:
+ case TargetOpcode::G_VECREDUCE_XOR:
+ case TargetOpcode::G_VECREDUCE_SMAX:
+ case TargetOpcode::G_VECREDUCE_SMIN:
+ case TargetOpcode::G_VECREDUCE_UMAX:
+ case TargetOpcode::G_VECREDUCE_UMIN:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ /// Get the opcode of the scalar operation equivalent to this reduction,
+ /// e.g. G_VECREDUCE_FADD maps to G_FADD and G_VECREDUCE_FMAX to G_FMAXNUM.
+ unsigned getScalarOpcForReduction() {
+ unsigned ScalarOpc;
+ switch (getOpcode()) {
+ case TargetOpcode::G_VECREDUCE_FADD:
+ ScalarOpc = TargetOpcode::G_FADD;
+ break;
+ case TargetOpcode::G_VECREDUCE_FMUL:
+ ScalarOpc = TargetOpcode::G_FMUL;
+ break;
+ case TargetOpcode::G_VECREDUCE_FMAX:
+ ScalarOpc = TargetOpcode::G_FMAXNUM;
+ break;
+ case TargetOpcode::G_VECREDUCE_FMIN:
+ ScalarOpc = TargetOpcode::G_FMINNUM;
+ break;
+ case TargetOpcode::G_VECREDUCE_ADD:
+ ScalarOpc = TargetOpcode::G_ADD;
+ break;
+ case TargetOpcode::G_VECREDUCE_MUL:
+ ScalarOpc = TargetOpcode::G_MUL;
+ break;
+ case TargetOpcode::G_VECREDUCE_AND:
+ ScalarOpc = TargetOpcode::G_AND;
+ break;
+ case TargetOpcode::G_VECREDUCE_OR:
+ ScalarOpc = TargetOpcode::G_OR;
+ break;
+ case TargetOpcode::G_VECREDUCE_XOR:
+ ScalarOpc = TargetOpcode::G_XOR;
+ break;
+ case TargetOpcode::G_VECREDUCE_SMAX:
+ ScalarOpc = TargetOpcode::G_SMAX;
+ break;
+ case TargetOpcode::G_VECREDUCE_SMIN:
+ ScalarOpc = TargetOpcode::G_SMIN;
+ break;
+ case TargetOpcode::G_VECREDUCE_UMAX:
+ ScalarOpc = TargetOpcode::G_UMAX;
+ break;
+ case TargetOpcode::G_VECREDUCE_UMIN:
+ ScalarOpc = TargetOpcode::G_UMIN;
+ break;
+ default:
+ llvm_unreachable("Unhandled reduction");
+ }
+ return ScalarOpc;
+ }
+};
+
+
} // namespace llvm
#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index a4dc9d99d772b3..023f72b84fd7c8 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4428,73 +4428,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
return Legalized;
}
-static unsigned getScalarOpcForReduction(unsigned Opc) {
- unsigned ScalarOpc;
- switch (Opc) {
- case TargetOpcode::G_VECREDUCE_FADD:
- ScalarOpc = TargetOpcode::G_FADD;
- break;
- case TargetOpcode::G_VECREDUCE_FMUL:
- ScalarOpc = TargetOpcode::G_FMUL;
- break;
- case TargetOpcode::G_VECREDUCE_FMAX:
- ScalarOpc = TargetOpcode::G_FMAXNUM;
- break;
- case TargetOpcode::G_VECREDUCE_FMIN:
- ScalarOpc = TargetOpcode::G_FMINNUM;
- break;
- case TargetOpcode::G_VECREDUCE_ADD:
- ScalarOpc = TargetOpcode::G_ADD;
- break;
- case TargetOpcode::G_VECREDUCE_MUL:
- ScalarOpc = TargetOpcode::G_MUL;
- break;
- case TargetOpcode::G_VECREDUCE_AND:
- ScalarOpc = TargetOpcode::G_AND;
- break;
- case TargetOpcode::G_VECREDUCE_OR:
- ScalarOpc = TargetOpcode::G_OR;
- break;
- case TargetOpcode::G_VECREDUCE_XOR:
- ScalarOpc = TargetOpcode::G_XOR;
- break;
- case TargetOpcode::G_VECREDUCE_SMAX:
- ScalarOpc = TargetOpcode::G_SMAX;
- break;
- case TargetOpcode::G_VECREDUCE_SMIN:
- ScalarOpc = TargetOpcode::G_SMIN;
- break;
- case TargetOpcode::G_VECREDUCE_UMAX:
- ScalarOpc = TargetOpcode::G_UMAX;
- break;
- case TargetOpcode::G_VECREDUCE_UMIN:
- ScalarOpc = TargetOpcode::G_UMIN;
- break;
- default:
- llvm_unreachable("Unhandled reduction");
- }
- return ScalarOpc;
-}
-
LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
- unsigned Opc = MI.getOpcode();
- assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
- Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
- "Sequential reductions not expected");
+ auto &RdxMI = cast<GVecReduce>(MI);
if (TypeIdx != 1)
return UnableToLegalize;
// The semantics of the normal non-sequential reductions allow us to freely
// re-associate the operation.
- auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+ auto [DstReg, DstTy, SrcReg, SrcTy] = RdxMI.getFirst2RegLLTs();
if (NarrowTy.isVector() &&
(SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
return UnableToLegalize;
- unsigned ScalarOpc = getScalarOpcForReduction(Opc);
+ unsigned ScalarOpc = RdxMI.getScalarOpcForReduction();
SmallVector<Register> SplitSrcs;
// If NarrowTy is a scalar then we're being asked to scalarize.
const unsigned NumParts =
@@ -4539,10 +4488,10 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
SmallVector<Register> PartialReductions;
for (unsigned Part = 0; Part < NumParts; ++Part) {
PartialReductions.push_back(
- MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
+ MIRBuilder.buildInstr(RdxMI.getOpcode(), {DstTy}, {SplitSrcs[Part]})
+ .getReg(0));
}
-
// If the types involved are powers of 2, we can generate intermediate vector
// ops, before generating a final reduction operation.
if (isPowerOf2_32(SrcTy.getNumElements()) &&
More information about the llvm-commits mailing list