[llvm] b26e6a8 - [GlobalISel] Add `GITypeOf` special type (#66079)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 31 01:57:15 PDT 2023
Author: Pierre van Houtryve
Date: 2023-10-31T09:57:10+01:00
New Revision: b26e6a8eb57da6bc0f6d968a7ff87be0f3862683
URL: https://github.com/llvm/llvm-project/commit/b26e6a8eb57da6bc0f6d968a7ff87be0f3862683
DIFF: https://github.com/llvm/llvm-project/commit/b26e6a8eb57da6bc0f6d968a7ff87be0f3862683.diff
LOG: [GlobalISel] Add `GITypeOf` special type (#66079)
Allows creating a register/immediate that uses the same type as a
matched operand.
Added:
llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
Modified:
llvm/docs/GlobalISel/MIRPatterns.rst
llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
llvm/include/llvm/Target/GlobalISel/Combine.td
llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
llvm/utils/TableGen/GlobalISelMatchTable.cpp
llvm/utils/TableGen/GlobalISelMatchTable.h
Removed:
################################################################################
diff --git a/llvm/docs/GlobalISel/MIRPatterns.rst b/llvm/docs/GlobalISel/MIRPatterns.rst
index fa70311f48572de..a3883b14b3e0bd6 100644
--- a/llvm/docs/GlobalISel/MIRPatterns.rst
+++ b/llvm/docs/GlobalISel/MIRPatterns.rst
@@ -101,6 +101,48 @@ pattern, you can try naming your patterns to see exactly where the issue is.
// using $x again here copies operand 1 from G_AND into the new inst.
(apply (COPY $root, $x))
+Types
+-----
+
+ValueType
+~~~~~~~~~
+
+Subclasses of ``ValueType`` are valid types, e.g. ``i32``.
+
+GITypeOf
+~~~~~~~~
+
+``GITypeOf<"$x">`` is a ``GISpecialType`` that allows for the creation of a
+register or immediate with the same type as another (register) operand.
+
+Operand:
+
+* An operand name as a string, prefixed by ``$``.
+
+Semantics:
+
+* Can only appear in an 'apply' pattern.
+* The operand name used must appear in the 'match' pattern of the
+ same ``GICombineRule``.
+
+.. code-block:: text
+ :caption: Example: Immediate
+
+ def mul_by_neg_one: GICombineRule <
+ (defs root:$root),
+ (match (G_MUL $dst, $x, -1)),
+ (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+ >;
+
+.. code-block:: text
+ :caption: Example: Temp Reg
+
+ def Test0 : GICombineRule<
+ (defs root:$dst),
+ (match (G_MUL $dst, $x, -1)),
+ (apply (G_SUB $dst, $zero, $x),
+ (G_CONSTANT GITypeOf<"$x">:$zero, (GITypeOf<"$x"> 0)))>;
+
Builtin Operations
------------------
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 65299e852574bd1..ba72a3b71ffd70b 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -405,9 +405,6 @@ class CombinerHelper {
void applyCombineTruncOfShift(MachineInstr &MI,
std::pair<MachineInstr *, LLT> &MatchInfo);
- /// Transform G_MUL(x, -1) to G_SUB(0, x)
- void applyCombineMulByNegativeOne(MachineInstr &MI);
-
/// Return true if any explicit use operand on \p MI is defined by a
/// G_IMPLICIT_DEF.
bool matchAnyExplicitUseIsUndef(MachineInstr &MI);
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
index 209f80c6d6d2877..6fcd9d09e1863cc 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
@@ -275,6 +275,12 @@ enum {
/// - StoreIdx - Store location in RecordedOperands.
GIM_RecordNamedOperand,
+ /// Records the LLT of a register operand into MatcherState::RecordedTypes.
+ /// - InsnID - Instruction ID
+ /// - OpIdx - Operand index
+ /// - TempTypeIdx - Temp Type Index, always negative (-1, -2, ...).
+ GIM_RecordRegType,
+
/// Fail the current try-block, or completely fail to match if there is no
/// current try-block.
GIM_Reject,
@@ -522,6 +528,10 @@ class GIMatchTableExecutor {
/// list. Currently such predicates don't have more then 3 arguments.
std::array<const MachineOperand *, 3> RecordedOperands;
+ /// Register types recorded by GIM_RecordRegType.
+ /// A negative type index is looked up here instead of in TypeObjects.
+ SmallVector<LLT, 4> RecordedTypes;
+
MatcherState(unsigned MaxRenderers);
};
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index fb03d5ec0bc89a9..32e2f21d775f303 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -92,6 +92,14 @@ bool GIMatchTableExecutor::executeMatchTable(
return true;
};
+ // A type index >= 0 refers to the type objects generated by TableGen; a
+ // negative index -N refers to State.RecordedTypes[1 + N].
+ auto getTypeFromIdx = [&](int64_t Idx) -> LLT {
+ if (Idx >= 0)
+ return ExecInfo.TypeObjects[Idx];
+ return State.RecordedTypes[1 - Idx];
+ };
+
while (true) {
assert(CurrentIdx != ~0u && "Invalid MatchTable index");
int64_t MatcherOpcode = MatchTable[CurrentIdx++];
@@ -627,8 +635,7 @@ bool GIMatchTableExecutor::executeMatchTable(
<< "), TypeID=" << TypeID << ")\n");
assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
MachineOperand &MO = State.MIs[InsnID]->getOperand(OpIdx);
- if (!MO.isReg() ||
- MRI.getType(MO.getReg()) != ExecInfo.TypeObjects[TypeID]) {
+ if (!MO.isReg() || MRI.getType(MO.getReg()) != getTypeFromIdx(TypeID)) {
if (handleReject() == RejectAndGiveUp)
return false;
}
@@ -679,6 +686,25 @@ bool GIMatchTableExecutor::executeMatchTable(
State.RecordedOperands[StoreIdx] = &State.MIs[InsnID]->getOperand(OpIdx);
break;
}
+ case GIM_RecordRegType: {
+ int64_t InsnID = MatchTable[CurrentIdx++];
+ int64_t OpIdx = MatchTable[CurrentIdx++];
+ int64_t TypeIdx = MatchTable[CurrentIdx++];
+
+ DEBUG_WITH_TYPE(TgtExecutor::getName(),
+ dbgs() << CurrentIdx << ": GIM_RecordRegType(MIs["
+ << InsnID << "]->getOperand(" << OpIdx
+ << "), TypeIdx=" << TypeIdx << ")\n");
+ assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
+ assert(TypeIdx < 0 && "Temp types always have negative indexes!");
+ // Indexes start at -1; -N is stored in slot 1 + N, as getTypeFromIdx reads.
+ TypeIdx = 1 - TypeIdx;
+ const auto &Op = State.MIs[InsnID]->getOperand(OpIdx);
+ if (State.RecordedTypes.size() <= static_cast<uint64_t>(TypeIdx))
+ State.RecordedTypes.resize(TypeIdx + 1, LLT());
+ State.RecordedTypes[TypeIdx] = MRI.getType(Op.getReg());
+ break;
+ }
case GIM_CheckRegBankForClass: {
int64_t InsnID = MatchTable[CurrentIdx++];
int64_t OpIdx = MatchTable[CurrentIdx++];
@@ -1275,7 +1301,7 @@ bool GIMatchTableExecutor::executeMatchTable(
int64_t TypeID = MatchTable[CurrentIdx++];
State.TempRegisters[TempRegID] =
- MRI.createGenericVirtualRegister(ExecInfo.TypeObjects[TypeID]);
+ MRI.createGenericVirtualRegister(getTypeFromIdx(TypeID));
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": TempRegs[" << TempRegID
<< "] = GIR_MakeTempReg(" << TypeID << ")\n");
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index bb8223ba3486a8d..63c485a5a6c6070 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -110,6 +110,24 @@ class GICombinePatFrag<dag outs, dag ins, list<dag> alts> {
list<dag> Alternatives = alts;
}
+//===----------------------------------------------------------------------===//
+// Pattern Special Types
+//===----------------------------------------------------------------------===//
+
+class GISpecialType;
+
+// In an apply pattern, GITypeOf<"$x"> gives a new temporary register or an
+// immediate the same type as the matched register operand $x.
+//
+// Only immediates and new temporary registers in the apply pattern can use it.
+//
+// TODO: Make this work in matchers as well?
+//
+// FIXME: Syntax is very ugly.
+class GITypeOf<string opName> : GISpecialType {
+ string OpName = opName;
+}
+
//===----------------------------------------------------------------------===//
// Pattern Builtins
//===----------------------------------------------------------------------===//
@@ -776,10 +794,9 @@ def trunc_shift: GICombineRule <
// Transform (mul x, -1) -> (sub 0, x)
def mul_by_neg_one: GICombineRule <
- (defs root:$root),
- (match (wip_match_opcode G_MUL):$root,
- [{ return Helper.matchConstantOp(${root}->getOperand(2), -1); }]),
- (apply [{ Helper.applyCombineMulByNegativeOne(*${root}); }])
+ (defs root:$dst),
+ (match (G_MUL $dst, $x, -1)),
+ (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
>;
// Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 3c2b5f490ccb871..51c268ab77c2220 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2351,18 +2351,6 @@ void CombinerHelper::applyCombineExtOfExt(
}
}
-void CombinerHelper::applyCombineMulByNegativeOne(MachineInstr &MI) {
- assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
- Register DstReg = MI.getOperand(0).getReg();
- Register SrcReg = MI.getOperand(1).getReg();
- LLT DstTy = MRI.getType(DstReg);
-
- Builder.setInstrAndDebugLoc(MI);
- Builder.buildSub(DstReg, Builder.buildConstant(DstTy, 0), SrcReg,
- MI.getFlags());
- MI.eraseFromParent();
-}
-
bool CombinerHelper::matchCombineTruncOfExt(
MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
new file mode 100644
index 000000000000000..496d86aeef2d10a
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
@@ -0,0 +1,49 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN: -combiners=MyCombiner %s | \
+// RUN: FileCheck %s
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+def Test0 : GICombineRule<
+ (defs root:$dst),
+ (match (G_MUL $dst, $src, -1)),
+ (apply (G_SUB $dst, (GITypeOf<"$src"> 0), $tmp),
+ (G_CONSTANT GITypeOf<"$dst">:$tmp, (GITypeOf<"$src"> 42)))>;
+
+// CHECK: const int64_t *GenMyCombiner::getMatchTable() const {
+// CHECK-NEXT: constexpr static int64_t MatchTable0[] = {
+// CHECK-NEXT: GIM_Try, /*On fail goto*//*Label 0*/ 57, // Rule ID 0 //
+// CHECK-NEXT: GIM_CheckSimplePredicate, GICXXPred_Simple_IsRule0Enabled,
+// CHECK-NEXT: GIM_CheckOpcode, /*MI*/0, TargetOpcode::G_MUL,
+// CHECK-NEXT: // MIs[0] dst
+// CHECK-NEXT: GIM_RecordRegType, /*MI*/0, /*Op*/0, /*TempTypeIdx*/-1,
+// CHECK-NEXT: // MIs[0] src
+// CHECK-NEXT: GIM_RecordRegType, /*MI*/0, /*Op*/1, /*TempTypeIdx*/-2,
+// CHECK-NEXT: // MIs[0] Operand 2
+// CHECK-NEXT: GIM_CheckConstantInt, /*MI*/0, /*Op*/2, -1,
+// CHECK-NEXT: GIR_MakeTempReg, /*TempRegID*/1, /*TypeID*/-2,
+// CHECK-NEXT: GIR_BuildConstant, /*TempRegID*/1, /*Val*/0,
+// CHECK-NEXT: GIR_MakeTempReg, /*TempRegID*/0, /*TypeID*/-1,
+// CHECK-NEXT: // Combiner Rule #0: Test0
+// CHECK-NEXT: GIR_BuildMI, /*InsnID*/0, /*Opcode*/TargetOpcode::G_CONSTANT,
+// CHECK-NEXT: GIR_AddTempRegister, /*InsnID*/0, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT: GIR_AddCImm, /*InsnID*/0, /*Type*/-2, /*Imm*/42,
+// CHECK-NEXT: GIR_EraseFromParent, /*InsnID*/0,
+// CHECK-NEXT: GIR_BuildMI, /*InsnID*/1, /*Opcode*/TargetOpcode::G_SUB,
+// CHECK-NEXT: GIR_Copy, /*NewInsnID*/1, /*OldInsnID*/0, /*OpIdx*/0, // dst
+// CHECK-NEXT: GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/1, /*TempRegFlags*/0,
+// CHECK-NEXT: GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT: GIR_Done,
+// CHECK-NEXT: // Label 0: @57
+// CHECK-NEXT: GIM_Reject,
+// CHECK-NEXT: };
+// CHECK-NEXT: return MatchTable0;
+// CHECK-NEXT: }
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+ Test0
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
index c871e603e4e05aa..4769bed97240125 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
@@ -79,7 +79,33 @@ def PatFragTest0 : GICombineRule<
(match (FooPF $dst)),
(apply (COPY $dst, (i32 0)))>;
+
+// CHECK: (CombineRule name:TypeOfProp id:2 root:x
+// CHECK-NEXT: (MatchPats
+// CHECK-NEXT: <match_root>__TypeOfProp_match_0:(CodeGenInstructionPattern G_ZEXT operands:[<def>$x, $y])
+// CHECK-NEXT: )
+// CHECK-NEXT: (ApplyPats
+// CHECK-NEXT: <apply_root>__TypeOfProp_apply_0:(CodeGenInstructionPattern G_ANYEXT operands:[<def>$x, GITypeOf<$y>:$tmp])
+// CHECK-NEXT: __TypeOfProp_apply_1:(CodeGenInstructionPattern G_ANYEXT operands:[<def>GITypeOf<$y>:$tmp, $y])
+// CHECK-NEXT: )
+// CHECK-NEXT: (OperandTable MatchPats
+// CHECK-NEXT: x -> __TypeOfProp_match_0
+// CHECK-NEXT: y -> <live-in>
+// CHECK-NEXT: )
+// CHECK-NEXT: (OperandTable ApplyPats
+// CHECK-NEXT: tmp -> __TypeOfProp_apply_1
+// CHECK-NEXT: x -> __TypeOfProp_apply_0
+// CHECK-NEXT: y -> <live-in>
+// CHECK-NEXT: )
+// CHECK-NEXT: )
+def TypeOfProp : GICombineRule<
+ (defs root:$x),
+ (match (G_ZEXT $x, $y)),
+ (apply (G_ANYEXT $x, GITypeOf<"$y">:$tmp),
+ (G_ANYEXT $tmp, $y))>;
+
def MyCombiner: GICombiner<"GenMyCombiner", [
InstTest0,
- PatFragTest0
+ PatFragTest0,
+ TypeOfProp
]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
index bc75b15233b5519..fd41a7d1d72417e 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
@@ -297,6 +297,28 @@ def VariadicsOutTest : GICombineRule<
(apply (COPY $a, (i32 0)),
(COPY $b, (i32 0)))>;
+// CHECK: (CombineRule name:TypeOfTest id:10 root:dst
+// CHECK-NEXT: (MatchPats
+// CHECK-NEXT: <match_root>__TypeOfTest_match_0:(CodeGenInstructionPattern COPY operands:[<def>$dst, $tmp])
+// CHECK-NEXT: __TypeOfTest_match_1:(CodeGenInstructionPattern G_ZEXT operands:[<def>$tmp, $src])
+// CHECK-NEXT: )
+// CHECK-NEXT: (ApplyPats
+// CHECK-NEXT: <apply_root>__TypeOfTest_apply_0:(CodeGenInstructionPattern G_MUL operands:[<def>$dst, (GITypeOf<$src> 0), (GITypeOf<$dst> -1)])
+// CHECK-NEXT: )
+// CHECK-NEXT: (OperandTable MatchPats
+// CHECK-NEXT: dst -> __TypeOfTest_match_0
+// CHECK-NEXT: src -> <live-in>
+// CHECK-NEXT: tmp -> __TypeOfTest_match_1
+// CHECK-NEXT: )
+// CHECK-NEXT: (OperandTable ApplyPats
+// CHECK-NEXT: dst -> __TypeOfTest_apply_0
+// CHECK-NEXT: )
+// CHECK-NEXT: )
+def TypeOfTest : GICombineRule<
+ (defs root:$dst),
+ (match (COPY $dst, $tmp),
+ (G_ZEXT $tmp, $src)),
+ (apply (G_MUL $dst, (GITypeOf<"$src"> 0), (GITypeOf<"$dst"> -1)))>;
def MyCombiner: GICombiner<"GenMyCombiner", [
WipOpcodeTest0,
@@ -308,5 +330,6 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
PatFragTest0,
PatFragTest1,
VariadicsInTest,
- VariadicsOutTest
+ VariadicsOutTest,
+ TypeOfTest
]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
new file mode 100644
index 000000000000000..6040d6def449766
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -0,0 +1,72 @@
+// RUN: not llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN: -combiners=MyCombiner %s 2>&1| \
+// RUN: FileCheck %s -implicit-check-not=error:
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+def NoDollarSign : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, $src)),
+ (apply (G_ANYEXT $dst, (GITypeOf<"unknown"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'unknown' ('GITypeOf<$unknown>') does not refer to a matched operand!
+def UnknownOperand : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, $src)),
+ (apply (G_ANYEXT $dst, (GITypeOf<"$unknown"> 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__UseInMatch_match_0' has type 'GITypeOf<$dst>'
+def UseInMatch : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, (GITypeOf<"$dst"> 0))),
+ (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: GISpecialType is not supported in GICombinePatFrag
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: operand 1 of '__PFWithTypeOF_alt0_pattern_0' has type 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Could not parse GICombinePatFrag 'PFWithTypeOF'
+def PFWithTypeOF: GICombinePatFrag<
+ (outs $dst), (ins),
+ [(pattern (G_ANYEXT $dst, (GITypeOf<"$dst"> 0)))]>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(PFWithTypeOF ?:$dst)'
+def UseInPF: GICombineRule<
+ (defs root:$dst),
+ (match (PFWithTypeOF $dst)),
+ (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__InferredUseInMatch_match_0' has type 'GITypeOf<$dst>'
+def InferredUseInMatch : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, $src)),
+ (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+def InferenceConflict : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, i32:$src)),
+ (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'tmp' ('GITypeOf<$tmp>') does not refer to a matched operand!
+def TypeOfApplyTmp : GICombineRule<
+ (defs root:$dst),
+ (match (G_ZEXT $dst, $src)),
+ (apply (G_ANYEXT $dst, i32:$tmp),
+ (G_ANYEXT $tmp, (GITypeOf<"$tmp"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse one or more rules
+def MyCombiner: GICombiner<"GenMyCombiner", [
+ NoDollarSign,
+ UnknownOperand,
+ UseInMatch,
+ UseInPF,
+ InferredUseInMatch,
+ InferenceConflict,
+ TypeOfApplyTmp
+]>;
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 7992cb4362a1718..0c7b33a7b9d889d 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -73,6 +73,8 @@ constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
constexpr StringLiteral PatFragClassName = "GICombinePatFrag";
constexpr StringLiteral BuiltinInstClassName = "GIBuiltinInst";
+constexpr StringLiteral SpecialTyClassName = "GISpecialType";
+constexpr StringLiteral TypeOfClassName = "GITypeOf";
std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
return "GICXXPred_Simple_IsRule" + to_string(CombinerRuleID) + "Enabled";
@@ -123,11 +125,6 @@ template <typename Container> auto values(Container &&C) {
return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
}
-LLTCodeGen getLLTCodeGenFromRecord(const Record *Ty) {
- assert(Ty->isSubClassOf("ValueType"));
- return LLTCodeGen(*MVTToLLT(getValueType(Ty)));
-}
-
//===- MatchData Handling -------------------------------------------------===//
/// Represents MatchData defined by the match stage and required by the apply
@@ -292,6 +289,112 @@ class CXXPredicateCode {
CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXMatchCode;
CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
+//===- PatternType --------------------------------------------------------===//
+
+/// Represents the type of a Pattern Operand.
+///
+/// Types have two forms:
+/// - LLTs, which are straightforward.
+/// - Special types, e.g. GITypeOf
+class PatternType {
+public:
+ PatternType() = default;
+ PatternType(const Record *R) : R(R) {}
+
+ bool isValidType() const { return !R || isLLT() || isSpecial(); }
+
+ bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
+ bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
+ bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+
+ StringRef getTypeOfOpName() const;
+ LLTCodeGen getLLTCodeGen() const;
+
+ bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
+
+ LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
+
+ explicit operator bool() const { return R != nullptr; }
+
+ bool operator==(const PatternType &Other) const;
+ bool operator!=(const PatternType &Other) const { return !operator==(Other); }
+
+ std::string str() const;
+
+private:
+ StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+
+ const Record *R = nullptr;
+};
+
+StringRef PatternType::getTypeOfOpName() const {
+ assert(isTypeOf());
+ StringRef Name = getRawOpName();
+ Name.consume_front("$");
+ return Name;
+}
+
+LLTCodeGen PatternType::getLLTCodeGen() const {
+ assert(isLLT());
+ return *MVTToLLT(getValueType(R));
+}
+
+LLTCodeGenOrTempType
+PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
+ assert(isValidType());
+
+ if (isLLT())
+ return getLLTCodeGen();
+
+ assert(isTypeOf());
+ auto &OM = RM.getOperandMatcher(getTypeOfOpName());
+ return OM.getTempTypeIdx(RM);
+}
+
+bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
+ if (!isTypeOf())
+ return true;
+
+ auto RawOpName = getRawOpName();
+ if (RawOpName.starts_with("$"))
+ return true;
+
+ PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
+ TypeOfClassName +
+ ": expected '$' followed by an operand name");
+ return false;
+}
+
+bool PatternType::operator==(const PatternType &Other) const {
+ // The same record (or two null records) always denotes the same type.
+ if (R == Other.R)
+ return true;
+
+ // Distinct GITypeOf records are equal when they name the same operand.
+ if (isTypeOf() && Other.isTypeOf())
+ return getTypeOfOpName() == Other.getTypeOfOpName();
+
+ return false;
+}
+
+std::string PatternType::str() const {
+ if (!R)
+ return "";
+
+ if (!isValidType())
+ return "<invalid>";
+
+ if (isLLT())
+ return R->getName().str();
+
+ assert(isSpecial());
+
+ if (isTypeOf())
+ return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+
+ llvm_unreachable("Unknown type!");
+}
+
//===- Pattern Base Class -------------------------------------------------===//
/// Base class for all patterns that can be written in an `apply`, `match` or
@@ -499,13 +606,15 @@ class InstructionOperand {
public:
using IntImmTy = int64_t;
- InstructionOperand(IntImmTy Imm, StringRef Name, const Record *Type)
+ InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
: Value(Imm), Name(insertStrRef(Name)), Type(Type) {
- assert(!Type || Type->isSubClassOf("ValueType"));
+ assert(Type.isValidType());
}
- InstructionOperand(StringRef Name, const Record *Type)
- : Name(insertStrRef(Name)), Type(Type) {}
+ InstructionOperand(StringRef Name, PatternType Type)
+ : Name(insertStrRef(Name)), Type(Type) {
+ assert(Type.isValidType());
+ }
bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
@@ -527,11 +636,12 @@ class InstructionOperand {
void setIsDef(bool Value = true) { Def = Value; }
bool isDef() const { return Def; }
- void setType(const Record *R) {
- assert((!Type || (Type == R)) && "Overwriting type!");
- Type = R;
+ void setType(PatternType NewType) {
+ assert((!Type || (Type == NewType)) && "Overwriting type!");
+ assert(NewType.isValidType());
+ Type = NewType;
}
- const Record *getType() const { return Type; }
+ PatternType getType() const { return Type; }
std::string describe() const {
if (!hasImmValue())
@@ -547,11 +657,11 @@ class InstructionOperand {
OS << "<def>";
bool NeedsColon = true;
- if (const Record *Ty = getType()) {
+ if (Type) {
if (hasImmValue())
- OS << "(" << Ty->getName() << " " << getImmValue() << ")";
+ OS << "(" << Type.str() << " " << getImmValue() << ")";
else
- OS << Ty->getName();
+ OS << Type.str();
} else if (hasImmValue())
OS << getImmValue();
else
@@ -566,7 +676,7 @@ class InstructionOperand {
private:
std::optional<int64_t> Value;
StringRef Name;
- const Record *Type = nullptr;
+ PatternType Type;
bool Def = false;
};
@@ -622,6 +732,10 @@ class InstructionPattern : public Pattern {
virtual StringRef getInstName() const = 0;
+ /// Emits \p Msg plus a note for every operand with a special type; returns
+ /// true if at least one diagnostic was emitted.
+ bool diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc, const Twine &Msg) const;
+
void reportUnreachable(ArrayRef<SMLoc> Locs) const;
virtual bool checkSemantics(ArrayRef<SMLoc> Loc);
@@ -633,6 +747,20 @@ class InstructionPattern : public Pattern {
SmallVector<InstructionOperand, 4> Operands;
};
+bool InstructionPattern::diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc,
+ const Twine &Msg) const {
+ bool HasDiag = false;
+ for (const auto &[Idx, Op] : enumerate(operands())) {
+ if (Op.getType().isSpecial()) {
+ PrintError(Loc, Msg);
+ PrintNote(Loc, "operand " + Twine(Idx) + " of '" + getName() +
+ "' has type '" + Op.getType().str() + "'");
+ HasDiag = true;
+ }
+ }
+ return HasDiag;
+}
+
void InstructionPattern::reportUnreachable(ArrayRef<SMLoc> Locs) const {
PrintError(Locs, "pattern '" + getName() + "' ('" + getInstName() +
"') is unreachable from the pattern root!");
@@ -829,17 +957,20 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
/// It infers the type of each operand, check it's consistent with the known
/// type of the operand, and then sets all of the types in all operands in
/// setAllOperandTypes.
+///
+/// It also handles verifying correctness of special types.
class OperandTypeChecker {
public:
OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
- bool check(InstructionPattern *P);
+ bool check(InstructionPattern *P,
+ std::function<bool(const PatternType &)> VerifyTypeOfOperand);
void setAllOperandTypes();
private:
struct OpTypeInfo {
- const Record *Type = nullptr;
+ PatternType Type;
InstructionPattern *TypeSrc = nullptr;
};
@@ -849,16 +980,26 @@ class OperandTypeChecker {
SmallVector<InstructionPattern *, 16> Pats;
};
-bool OperandTypeChecker::check(InstructionPattern *P) {
+bool OperandTypeChecker::check(
+ InstructionPattern *P,
+ std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
Pats.push_back(P);
- for (auto &Op : P->named_operands()) {
- const Record *Ty = Op.getType();
+ for (auto &Op : P->operands()) {
+ const auto Ty = Op.getType();
if (!Ty)
continue;
- auto &Info = Types[Op.getOperandName()];
+ if (!Ty.checkSemantics(DiagLoc))
+ return false;
+
+ if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
+ return false;
+ if (!Op.isNamedOperand())
+ continue;
+
+ auto &Info = Types[Op.getOperandName()];
if (!Info.Type) {
Info.Type = Ty;
Info.TypeSrc = P;
@@ -868,9 +1009,9 @@ bool OperandTypeChecker::check(InstructionPattern *P) {
if (Info.Type != Ty) {
PrintError(DiagLoc, "conflicting types for operand '" +
Op.getOperandName() + "': first seen with '" +
- Info.Type->getName() + "' in '" +
+ Info.Type.str() + "' in '" +
Info.TypeSrc->getName() + ", now seen with '" +
- Ty->getName() + "' in '" + P->getName() + "'");
+ Ty.str() + "' in '" + P->getName() + "'");
return false;
}
}
@@ -1058,7 +1199,12 @@ bool PatFrag::checkSemantics() {
PatFragClassName);
return false;
case Pattern::K_CXX:
+ continue;
case Pattern::K_CodeGenInstruction:
+ if (cast<CodeGenInstructionPattern>(Pat.get())->diagnoseAllSpecialTypes(
+ Def.getLoc(), SpecialTyClassName + " is not supported in " +
+ PatFragClassName))
+ return false;
continue;
case Pattern::K_PatFrag:
// TODO: It's just that the emitter doesn't handle it but technically
@@ -1142,12 +1288,16 @@ bool PatFrag::checkSemantics() {
// TODO: find unused params
+ const auto CheckTypeOf = [&](const PatternType &) -> bool {
+ llvm_unreachable("GITypeOf should have been rejected earlier!");
+ };
+
// Now, typecheck all alternatives.
for (auto &Alt : Alts) {
OperandTypeChecker OTC(Def.getLoc());
for (auto &Pat : Alt.Pats) {
if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
- if (!OTC.check(IP))
+ if (!OTC.check(IP, CheckTypeOf))
return false;
}
}
@@ -1954,21 +2104,49 @@ bool CombineRuleBuilder::hasEraseRoot() const {
bool CombineRuleBuilder::typecheckPatterns() {
OperandTypeChecker OTC(RuleDef.getLoc());
+ const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
+ // Accept for now; special types in 'match' are rejected after inference.
+ return true;
+ };
+
for (auto &Pat : values(MatchPats)) {
if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
- if (!OTC.check(IP))
+ if (!OTC.check(IP, CheckMatchTypeOf))
return false;
}
}
+ const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
+ // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+ const auto OpName = Ty.getTypeOfOpName();
+ if (MatchOpTable.lookup(OpName).Found)
+ return true;
+
+ PrintError("'" + OpName + "' ('" + Ty.str() +
+ "') does not refer to a matched operand!");
+ return false;
+ };
+
for (auto &Pat : values(ApplyPats)) {
if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
- if (!OTC.check(IP))
+ if (!OTC.check(IP, CheckApplyTypeOf))
return false;
}
}
OTC.setAllOperandTypes();
+
+ // Always check this after in case inference adds some special types to the
+ // match patterns.
+ for (auto &Pat : values(MatchPats)) {
+ if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
+ if (IP->diagnoseAllSpecialTypes(
+ RuleDef.getLoc(),
+ SpecialTyClassName + " is not supported in 'match' patterns")) {
+ return false;
+ }
+ }
+ }
return true;
}
@@ -2461,10 +2639,12 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
if (DagOp->getNumArgs() != 1)
return ParseErr();
- Record *ImmTy = DagOp->getOperatorAsDef(RuleDef.getLoc());
- if (!ImmTy->isSubClassOf("ValueType")) {
+ const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
+ PatternType ImmTy(TyDef);
+ if (!ImmTy.isValidType()) {
PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
- "', '" + ImmTy->getName() + "' is not a ValueType!");
+ "', '" + TyDef->getName() + "' is not a ValueType or " +
+ SpecialTyClassName);
return false;
}
@@ -2491,12 +2671,13 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
return false;
}
const Record *Def = DefI->getDef();
- if (!Def->isSubClassOf("ValueType")) {
- PrintError("invalid operand type: '" + Def->getName() +
- "' is not a ValueType");
+ PatternType Ty(Def);
+ if (!Ty.isValidType()) {
+ PrintError("invalid operand type: '" + Def->getName() +
+ "' is not a ValueType or " + SpecialTyClassName);
return false;
}
- IP.addOperand(OpName->getAsUnquotedString(), Def);
+ IP.addOperand(OpName->getAsUnquotedString(), Ty);
return true;
}
@@ -2823,8 +3004,8 @@ bool CombineRuleBuilder::emitPatFragMatchPattern(
StringRef PFName = PF.getName();
PrintWarning("impossible type constraints: operand " + Twine(PIdx) +
" of '" + PFP.getName() + "' has type '" +
- ArgOp.getType()->getName() + "', but '" + PFName +
- "' constrains it to '" + O.getType()->getName() + "'");
+ ArgOp.getType().str() + "', but '" + PFName +
+ "' constrains it to '" + O.getType().str() + "'");
if (ArgOp.isNamedOperand())
PrintNote("operand " + Twine(PIdx) + " of '" + PFP.getName() +
"' is '" + ArgOp.getOperandName() + "'");
@@ -3055,17 +3236,18 @@ bool CombineRuleBuilder::emitInstructionApplyPattern(
// This is a brand new register.
TempRegID = M.allocateTempRegID();
OperandToTempRegID[OpName] = TempRegID;
- const Record *Ty = Op.getType();
+ const auto Ty = Op.getType();
if (!Ty) {
PrintError("def of a new register '" + OpName +
"' in the apply patterns must have a type");
return false;
}
+
declareTempRegExpansion(CE, TempRegID, OpName);
// Always insert the action at the beginning, otherwise we may end up
// using the temp reg before it's available.
M.insertAction<MakeTempRegisterAction>(
- M.actions_begin(), getLLTCodeGenFromRecord(Ty), TempRegID);
+ M.actions_begin(), Ty.getLLTCodeGenOrTempType(M), TempRegID);
}
DstMI.addRenderer<TempRegRenderer>(TempRegID);
@@ -3088,7 +3270,7 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
// G_CONSTANT is a special case and needs a CImm though so this is likely a
// mistake.
const bool isGConstant = P.is("G_CONSTANT");
- const Record *Ty = O.getType();
+ const auto Ty = O.getType();
if (!Ty) {
if (isGConstant) {
PrintError("'G_CONSTANT' immediate must be typed!");
@@ -3101,16 +3283,17 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
return true;
}
- LLTCodeGen LLT = getLLTCodeGenFromRecord(Ty);
+ auto ImmTy = Ty.getLLTCodeGenOrTempType(M);
+
if (isGConstant) {
- DstMI.addRenderer<ImmRenderer>(O.getImmValue(), LLT);
+ DstMI.addRenderer<ImmRenderer>(O.getImmValue(), ImmTy);
return true;
}
unsigned TempRegID = M.allocateTempRegID();
// Ensure MakeTempReg & the BuildConstantAction occur at the beginning.
- auto InsertIt =
- M.insertAction<MakeTempRegisterAction>(M.actions_begin(), LLT, TempRegID);
+ auto InsertIt = M.insertAction<MakeTempRegisterAction>(M.actions_begin(),
+ ImmTy, TempRegID);
M.insertAction<BuildConstantAction>(++InsertIt, TempRegID, O.getImmValue());
DstMI.addRenderer<TempRegRenderer>(TempRegID);
return true;
@@ -3227,8 +3410,14 @@ bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
// Always emit a check for unnamed operands.
if (OpName.empty() ||
!M.getOperandMatcher(OpName).contains<LLTOperandMatcher>()) {
- if (const Record *Ty = RemappedO.getType())
- OM.addPredicate<LLTOperandMatcher>(getLLTCodeGenFromRecord(Ty));
+ if (const auto Ty = RemappedO.getType()) {
+ // TODO: We could support GITypeOf here, provided the OperandMatcher
+ // already exists. However, making that work is clunky and not very
+ // useful, so GITypeOf is currently rejected in typecheckPatterns
+ // instead.
+ assert(Ty.isLLT() && "Only LLTs are supported in match patterns!");
+ OM.addPredicate<LLTOperandMatcher>(Ty.getLLTCodeGen());
+ }
}
// Stop here if the operand is a def, or if it had no name.
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
index 9a4a375f34bdb91..6ec85269e6e20d0 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.cpp
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
@@ -822,6 +822,15 @@ const OperandMatcher &RuleMatcher::getPhysRegOperandMatcher(Record *Reg) const {
return *I->second;
}
+/// Non-const lookup of the OperandMatcher registered under \p Name in
+/// DefinedOperands. Reports a fatal error (via PrintFatalError) if no operand
+/// with that name was declared in this rule.
+OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) {
+  auto Found = DefinedOperands.find(Name);
+  if (Found != DefinedOperands.end())
+    return *Found->second;
+
+  PrintFatalError(SrcLoc, "Operand " + Name + " was not declared in matcher");
+}
+
const OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) const {
const auto &I = DefinedOperands.find(Name);
@@ -1081,6 +1090,17 @@ void RecordNamedOperandMatcher::emitPredicateOpcodes(MatchTable &Table,
<< MatchTable::Comment("Name : " + Name) << MatchTable::LineBreak;
}
+//===- RecordRegisterType ------------------------------------------===//
+
+/// Emits a `GIM_RecordRegType` opcode whose operands are the instruction
+/// variable (MI), the operand index (Op) and the temp type slot
+/// (TempTypeIdx) that this predicate writes.
+void RecordRegisterType::emitPredicateOpcodes(MatchTable &Table,
+                                              RuleMatcher &Rule) const {
+  // Temp type slots are handed out counting down from -1, so a valid slot
+  // index is always negative.
+  assert(Idx < 0 && "Temp types always have negative indexes!");
+
+  Table << MatchTable::Opcode("GIM_RecordRegType");
+  Table << MatchTable::Comment("MI") << MatchTable::IntValue(InsnVarID);
+  Table << MatchTable::Comment("Op") << MatchTable::IntValue(OpIdx);
+  Table << MatchTable::Comment("TempTypeIdx") << MatchTable::IntValue(Idx);
+  Table << MatchTable::LineBreak;
+}
+
//===- ComplexPatternOperandMatcher ---------------------------------------===//
void ComplexPatternOperandMatcher::emitPredicateOpcodes(
@@ -1196,6 +1216,18 @@ std::string OperandMatcher::getOperandExpr(unsigned InsnVarID) const {
unsigned OperandMatcher::getInsnVarID() const { return Insn.getInsnVarID(); }
+/// Returns the temp type index associated with this operand. The first call
+/// allocates a fresh index from \p Rule and attaches a RecordRegisterType
+/// predicate for it; later calls return the same index.
+TempTypeIdx OperandMatcher::getTempTypeIdx(RuleMatcher &Rule) {
+  // TTIdx stays 0 until assigned; every assigned index is negative.
+  if (TTIdx < 0)
+    return TTIdx;
+
+  TTIdx = Rule.getNextTempTypeIdx();
+  assert(TTIdx < 0);
+  addPredicate<RecordRegisterType>(TTIdx);
+  return TTIdx;
+}
+
void OperandMatcher::emitPredicateOpcodes(MatchTable &Table,
RuleMatcher &Rule) {
if (!Optimized) {
@@ -2092,9 +2124,7 @@ void MakeTempRegisterAction::emitActionOpcodes(MatchTable &Table,
RuleMatcher &Rule) const {
Table << MatchTable::Opcode("GIR_MakeTempReg")
<< MatchTable::Comment("TempRegID") << MatchTable::IntValue(TempRegID)
- << MatchTable::Comment("TypeID")
- << MatchTable::NamedValue(Ty.getCxxEnumValue())
- << MatchTable::LineBreak;
+ << MatchTable::Comment("TypeID") << Ty << MatchTable::LineBreak;
}
} // namespace gi
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.h b/llvm/utils/TableGen/GlobalISelMatchTable.h
index 5608bab482bfd34..364f2a1ec725d53 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.h
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.h
@@ -273,6 +273,40 @@ extern std::set<LLTCodeGen> KnownTypes;
/// MVTs that don't map cleanly to an LLT (e.g., iPTR, *any, ...).
std::optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT);
+/// Index of a temporary type slot recorded while matching.
+using TempTypeIdx = int64_t;
+
+/// Holds either a concrete LLTCodeGen or a TempTypeIdx referring to a type
+/// that is recorded from a matched operand.
+class LLTCodeGenOrTempType {
+  std::variant<LLTCodeGen, TempTypeIdx> Value;
+
+public:
+  LLTCodeGenOrTempType(const LLTCodeGen &LLT) : Value(LLT) {}
+  LLTCodeGenOrTempType(TempTypeIdx TempTy) : Value(TempTy) {}
+
+  /// True if a concrete LLT is held.
+  bool isLLTCodeGen() const {
+    return std::holds_alternative<LLTCodeGen>(Value);
+  }
+
+  /// True if a temp type index is held.
+  bool isTempTypeIdx() const {
+    return std::holds_alternative<TempTypeIdx>(Value);
+  }
+
+  /// Returns the held LLT; only valid when isLLTCodeGen() is true.
+  const LLTCodeGen &getLLTCodeGen() const {
+    assert(isLLTCodeGen());
+    return std::get<LLTCodeGen>(Value);
+  }
+
+  /// Returns the held temp type index; only valid when isTempTypeIdx() is
+  /// true.
+  TempTypeIdx getTempTypeIdx() const {
+    assert(isTempTypeIdx());
+    return std::get<TempTypeIdx>(Value);
+  }
+};
+
+/// Streams \p Ty into \p Table. A concrete LLT is written as its C++ enum
+/// name and a temp type as its integer index.
+inline MatchTable &operator<<(MatchTable &Table,
+                              const LLTCodeGenOrTempType &Ty) {
+  if (Ty.isLLTCodeGen()) {
+    const LLTCodeGen &LLT = Ty.getLLTCodeGen();
+    Table << MatchTable::NamedValue(LLT.getCxxEnumValue());
+    return Table;
+  }
+
+  Table << MatchTable::IntValue(Ty.getTempTypeIdx());
+  return Table;
+}
+
//===- Matchers -----------------------------------------------------------===//
class Matcher {
public:
@@ -459,6 +493,9 @@ class RuleMatcher : public Matcher {
/// ID for the next temporary register ID allocated with allocateTempRegID()
unsigned NextTempRegID;
+ /// ID for the next recorded type. Starts at -1 and counts down.
+ TempTypeIdx NextTempTypeIdx = -1;
+
// HwMode predicate index for this rule. -1 if no HwMode.
int HwModeIdx = -1;
@@ -498,6 +535,8 @@ class RuleMatcher : public Matcher {
RuleMatcher(RuleMatcher &&Other) = default;
RuleMatcher &operator=(RuleMatcher &&Other) = default;
+ TempTypeIdx getNextTempTypeIdx() { return NextTempTypeIdx--; }
+
uint64_t getRuleID() const { return RuleID; }
InstructionMatcher &addInstructionMatcher(StringRef SymbolicName);
@@ -602,6 +641,7 @@ class RuleMatcher : public Matcher {
}
InstructionMatcher &getInstructionMatcher(StringRef SymbolicName) const;
+ OperandMatcher &getOperandMatcher(StringRef Name);
const OperandMatcher &getOperandMatcher(StringRef Name) const;
const OperandMatcher &getPhysRegOperandMatcher(Record *) const;
@@ -762,6 +802,7 @@ class PredicateMatcher {
OPM_RegBank,
OPM_MBB,
OPM_RecordNamedOperand,
+ OPM_RecordRegType,
};
protected:
@@ -963,6 +1004,30 @@ class RecordNamedOperandMatcher : public OperandPredicateMatcher {
RuleMatcher &Rule) const override;
};
+/// Operand predicate that records the type of a register operand into one of
+/// the rule's temporary type slots, identified by a TempTypeIdx.
+class RecordRegisterType : public OperandPredicateMatcher {
+protected:
+  /// Temp type slot that receives the operand's type.
+  TempTypeIdx Idx;
+
+public:
+  RecordRegisterType(unsigned InsnVarID, unsigned OpIdx, TempTypeIdx Idx)
+      : OperandPredicateMatcher(OPM_RecordRegType, InsnVarID, OpIdx),
+        Idx(Idx) {}
+
+  static bool classof(const PredicateMatcher *P) {
+    return P->getKind() == OPM_RecordRegType;
+  }
+
+  /// Two RecordRegisterType predicates are identical when the base operand
+  /// predicate matches and they write the same temp type slot.
+  bool isIdentical(const PredicateMatcher &B) const override {
+    if (!OperandPredicateMatcher::isIdentical(B))
+      return false;
+    const auto *Other = cast<RecordRegisterType>(&B);
+    return Idx == Other->Idx;
+  }
+
+  void emitPredicateOpcodes(MatchTable &Table,
+                            RuleMatcher &Rule) const override;
+};
+
/// Generates code to check that an operand is a particular target constant.
class ComplexPatternOperandMatcher : public OperandPredicateMatcher {
protected:
@@ -1169,6 +1234,8 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
/// countRendererFns().
unsigned AllocatedTemporariesBaseID;
+ TempTypeIdx TTIdx = 0;
+
public:
OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx,
const std::string &SymbolicName,
@@ -1196,6 +1263,11 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
unsigned getOpIdx() const { return OpIdx; }
unsigned getInsnVarID() const;
+ /// If this OperandMatcher has not been assigned a TempTypeIdx yet, assigns it
+ /// one and adds a `RecordRegisterType` predicate to this matcher. If one has
+ /// already been assigned, simply returns it.
+ TempTypeIdx getTempTypeIdx(RuleMatcher &Rule);
+
std::string getOperandExpr(unsigned InsnVarID) const;
InstructionMatcher &getInstructionMatcher() const { return Insn; }
@@ -1955,15 +2027,16 @@ class ImmRenderer : public OperandRenderer {
protected:
unsigned InsnID;
int64_t Imm;
- std::optional<LLTCodeGen> CImmLLT;
+ std::optional<LLTCodeGenOrTempType> CImmLLT;
public:
ImmRenderer(unsigned InsnID, int64_t Imm)
: OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm) {}
- ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGen &CImmLLT)
+ ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGenOrTempType &CImmLLT)
: OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm), CImmLLT(CImmLLT) {
- KnownTypes.insert(CImmLLT);
+ if (CImmLLT.isLLTCodeGen())
+ KnownTypes.insert(CImmLLT.getLLTCodeGen());
}
static bool classof(const OperandRenderer *R) {
@@ -1976,8 +2049,7 @@ class ImmRenderer : public OperandRenderer {
"ConstantInt immediate are only for combiners!");
Table << MatchTable::Opcode("GIR_AddCImm")
<< MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
- << MatchTable::Comment("Type")
- << MatchTable::NamedValue(CImmLLT->getCxxEnumValue())
+ << MatchTable::Comment("Type") << *CImmLLT
<< MatchTable::Comment("Imm") << MatchTable::IntValue(Imm)
<< MatchTable::LineBreak;
} else {
@@ -2290,13 +2362,14 @@ class ConstrainOperandToRegClassAction : public MatchAction {
/// instructions together.
class MakeTempRegisterAction : public MatchAction {
private:
- LLTCodeGen Ty;
+ LLTCodeGenOrTempType Ty;
unsigned TempRegID;
public:
- MakeTempRegisterAction(const LLTCodeGen &Ty, unsigned TempRegID)
+ MakeTempRegisterAction(const LLTCodeGenOrTempType &Ty, unsigned TempRegID)
: MatchAction(AK_MakeTempReg), Ty(Ty), TempRegID(TempRegID) {
- KnownTypes.insert(Ty);
+ if (Ty.isLLTCodeGen())
+ KnownTypes.insert(Ty.getLLTCodeGen());
}
static bool classof(const MatchAction *A) {
More information about the llvm-commits
mailing list