[llvm] 550d93a - [RISCV] Combine comparison and logic ops
Sergey Kachkov via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 23 06:10:40 PST 2022
Author: Ilya Andreev
Date: 2022-12-23T17:10:21+03:00
New Revision: 550d93ab1d2ec27efe5c5791f16ef31e3f74a6a6
URL: https://github.com/llvm/llvm-project/commit/550d93ab1d2ec27efe5c5791f16ef31e3f74a6a6
DIFF: https://github.com/llvm/llvm-project/commit/550d93ab1d2ec27efe5c5791f16ef31e3f74a6a6.diff
LOG: [RISCV] Combine comparison and logic ops
Two comparison operations joined by a logical operation are combined into a single MIN or MAX selection followed by one comparison operation.
For the optimization to be applied, the following conditions have to be satisfied:
1. The two comparison operations must share one common operand.
2. Only signed and unsigned integer comparisons are supported.
3. Both comparisons have to be the same with respect to the common operand.
4. The comparisons have no users other than the logic operation.
5. Every combination of comparison with AND or OR is supported.
It will convert
%l0 = %a < %c
%l1 = %b < %c
%res = %l0 or %l1
into
%sel = min(%a, %b)
%res = %sel < %c
It supports several comparison operations (<, <=, >, >=), signed and unsigned values, and different operand orders, as long as they do not violate the conditions above.
Differential Revision: https://reviews.llvm.org/D134277
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 764802f931e47..3041e4685b2ae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8465,6 +8465,226 @@ static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Helper class contains information about comparison operation.
+// The first two operands of this operation are compared values and the
+// last one is the operation.
+// Compared values are stored in Ops.
+// Comparison operation is stored in CCode.
+class CmpOpInfo {
+ static unsigned constexpr Size = 2u;
+
+ // Type for storing operands of compare operation.
+ using OpsArray = std::array<SDValue, Size>;
+ OpsArray Ops;
+
+ using const_iterator = OpsArray::const_iterator;
+ const_iterator begin() const { return Ops.begin(); }
+ const_iterator end() const { return Ops.end(); }
+
+ ISD::CondCode CCode;
+
+ unsigned CommonPos{Size};
+ unsigned DifferPos{Size};
+
+ // Sets CommonPos and DifferPos based on incoming position
+ // of common operand CPos.
+ void setPositions(const_iterator CPos) {
+ assert(CPos != Ops.end() && "Common operand has to be in OpsArray.\n");
+ CommonPos = CPos == Ops.begin() ? 0 : 1;
+ DifferPos = 1 - CommonPos;
+ assert((DifferPos == 0 || DifferPos == 1) &&
+ "Positions can be only 0 or 1.");
+ }
+
+ // Private constructor of comparison info based on comparison operator.
+ // It is private because a CmpOpInfo is only meaningful relative to another
+ // comparison operator. Therefore, infos about comparison operations
+ // have to be collected simultaneously via CmpOpInfo::getInfoAbout().
+ CmpOpInfo(const SDValue &CmpOp)
+ : Ops{CmpOp.getOperand(0), CmpOp.getOperand(1)},
+ CCode{cast<CondCodeSDNode>(CmpOp.getOperand(2))->get()} {}
+
+ // Finds common operand of Op1 and Op2 and finishes filling CmpOpInfos.
+ // Returns true if common operand is found. Otherwise - false.
+ static bool establishCorrespondence(CmpOpInfo &Op1, CmpOpInfo &Op2) {
+ const auto CommonOpIt1 =
+ std::find_first_of(Op1.begin(), Op1.end(), Op2.begin(), Op2.end());
+ if (CommonOpIt1 == Op1.end())
+ return false;
+
+ const auto CommonOpIt2 = std::find(Op2.begin(), Op2.end(), *CommonOpIt1);
+ assert(CommonOpIt2 != Op2.end() &&
+ "Cannot find common operand in the second comparison operation.");
+
+ Op1.setPositions(CommonOpIt1);
+ Op2.setPositions(CommonOpIt2);
+
+ return true;
+ }
+
+public:
+ CmpOpInfo(const CmpOpInfo &) = default;
+ CmpOpInfo(CmpOpInfo &&) = default;
+
+ SDValue const &operator[](unsigned Pos) const {
+ assert(Pos < Size && "Out of range\n");
+ return Ops[Pos];
+ }
+
+ // Creates infos about comparison operations CmpOp0 and CmpOp1.
+ // If there is no common operand returns None. Otherwise, returns
+ // correspondence info about comparison operations.
+ static std::optional<std::pair<CmpOpInfo, CmpOpInfo>>
+ getInfoAbout(SDValue const &CmpOp0, SDValue const &CmpOp1) {
+ CmpOpInfo Op0{CmpOp0};
+ CmpOpInfo Op1{CmpOp1};
+ if (!establishCorrespondence(Op0, Op1))
+ return std::nullopt;
+ return std::make_pair(Op0, Op1);
+ }
+
+ // Returns position of common operand.
+ unsigned getCPos() const { return CommonPos; }
+
+ // Returns position of the differing operand.
+ unsigned getDPos() const { return DifferPos; }
+
+ // Returns common operand.
+ SDValue const &getCOp() const { return operator[](CommonPos); }
+
+ // Returns the differing operand.
+ SDValue const &getDOp() const { return operator[](DifferPos); }
+
+ // Returns condition code of comparison operation.
+ ISD::CondCode getCondCode() const { return CCode; }
+};
+
+// Verifies conditions to apply an optimization.
+// Returns Reference comparison code and three operands A, B, C.
+// Conditions for optimization:
+// One operand of the comparisons has to be common.
+// This operand is written to C.
+// The two other operands are different. They are written to A and B.
+// Comparisons have to be similar with respect to common operand C.
+// e.g. A < C; C > B are similar
+// but A < C; B > C are not.
+// Reference comparison code is the comparison code if
+// common operand is right placed.
+// e.g. C > A will be swapped to A < C.
+static std::optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+ LLVM_DEBUG(
+ dbgs() << "Checking conditions for comparison operation combining.\n";);
+
+ SDValue V0 = N->getOperand(0);
+ SDValue V1 = N->getOperand(1);
+ assert(V0.getValueType() == V1.getValueType() &&
+ "Operations must have the same value type.");
+
+ // Condition 1. Operations have to be used only in logic operation.
+ if (!V0.hasOneUse() || !V1.hasOneUse())
+ return std::nullopt;
+
+ // Condition 2. Operands have to be comparison operations.
+ if (V0.getOpcode() != ISD::SETCC || V1.getOpcode() != ISD::SETCC)
+ return std::nullopt;
+
+ // Condition 3.1. Operations only with integers.
+ if (!V0.getOperand(0).getValueType().isInteger())
+ return std::nullopt;
+
+ const auto ComparisonInfo = CmpOpInfo::getInfoAbout(V0, V1);
+ // Condition 3.2. Common operand has to be in comparison.
+ if (!ComparisonInfo)
+ return std::nullopt;
+
+ const auto [Op0, Op1] = ComparisonInfo.value();
+
+ LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << Op0.getCPos()
+ << " and " << Op1.getCPos() << '\n';);
+ // If common operand at the first position then swap operation to convert to
+ // strict pattern. Common operand has to be right hand side.
+ ISD::CondCode RefCond = Op0.getCondCode();
+ ISD::CondCode AssistCode = Op1.getCondCode();
+ if (!Op0.getCPos())
+ RefCond = ISD::getSetCCSwappedOperands(RefCond);
+ if (!Op1.getCPos())
+ AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+ LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
+ // If there are different comparison operations then do not perform an
+ // optimization. a < c; c < b -> will be changed to b > c.
+ if (RefCond != AssistCode)
+ return std::nullopt;
+
+ // Conditions can be only similar to Less or Greater. (>, >=, <, <=)
+ // Applying this mask to the operation will determine Less and Greater
+ // operations.
+ const unsigned CmpMask = 0b110;
+ const unsigned MaskedOpcode = CmpMask & RefCond;
+ // If masking gave 0b110, then this is an operation NE, O or TRUE.
+ if (MaskedOpcode == CmpMask)
+ return std::nullopt;
+ // If masking gave 0, then this is an operation E, UO or FALSE.
+ if (MaskedOpcode == 0)
+ return std::nullopt;
+ // Everything else is similar to Less or Greater.
+
+ SDValue A = Op0.getDOp();
+ SDValue B = Op1.getDOp();
+ SDValue C = Op0.getCOp();
+
+ LLVM_DEBUG(
+ dbgs() << "The conditions for combining comparisons are satisfied.\n";);
+ return std::make_tuple(RefCond, A, B, C);
+}
+
+static ISD::NodeType getSelectionCode(bool IsUnsigned, bool IsAnd,
+ bool IsGreaterOp) {
+ // Codes of selection operation. The first index selects signed or unsigned,
+ // the second index selects MIN/MAX.
+ static constexpr ISD::NodeType SelectionCodes[2][2] = {
+ {ISD::SMIN, ISD::SMAX}, {ISD::UMIN, ISD::UMAX}};
+ const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
+ return SelectionCodes[IsUnsigned][ChooseSelCode];
+}
+
+// Combines two comparison operations and a logic operation into one selection
+// operation (min, max) and one comparison operation. Returns the newly
+// constructed node if the conditions for the optimization are satisfied.
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ if (!Subtarget.hasStdExtZbb())
+ return SDValue();
+
+ const unsigned BitOpcode = N->getOpcode();
+ assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+ "This optimization can be used only with AND/OR operations");
+
+ const auto Props = verifyCompareConds(N, DAG);
+ // If conditions are invalidated then do not perform an optimization.
+ if (!Props)
+ return SDValue();
+
+ const auto [RefOpcode, A, B, C] = Props.value();
+ const EVT CmpOpVT = A.getValueType();
+
+ const bool IsGreaterOp = RefOpcode & 0b10;
+ const bool IsUnsigned = ISD::isUnsignedIntSetCC(RefOpcode);
+ assert((IsUnsigned || ISD::isSignedIntSetCC(RefOpcode)) &&
+ "Operation neither with signed or unsigned integers.");
+
+ const bool IsAnd = BitOpcode == ISD::AND;
+ const ISD::NodeType PickCode =
+ getSelectionCode(IsUnsigned, IsAnd, IsGreaterOp);
+
+ SDLoc DL(N);
+ SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
+ SDValue Cmp =
+ DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
+
+ return Cmp;
+}
+
static SDValue performANDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
@@ -8489,6 +8709,9 @@ static SDValue performANDCombine(SDNode *N,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
}
+ if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+ return V;
+
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
@@ -8505,6 +8728,9 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
+ if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+ return V;
+
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
diff --git a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
index 4d02c539fa0b2..b94c50dd7e8c8 100644
--- a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
+++ b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
@@ -12,9 +12,8 @@
define i1 @ulo(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ulo:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ult i64 %a, %c
%l1 = icmp ult i64 %b, %c
@@ -25,9 +24,8 @@ define i1 @ulo(i64 %c, i64 %a, i64 %b) {
define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ulo_swap1:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %c, %a
%l1 = icmp ult i64 %b, %c
@@ -38,9 +36,8 @@ define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ulo_swap2:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ult i64 %a, %c
%l1 = icmp ugt i64 %c, %b
@@ -51,9 +48,8 @@ define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ulo_swap12:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %c, %a
%l1 = icmp ugt i64 %c, %b
@@ -65,9 +61,8 @@ define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
define i1 @ula(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ula:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ult i64 %a, %c
%l1 = icmp ult i64 %b, %c
@@ -78,9 +73,8 @@ define i1 @ula(i64 %c, i64 %a, i64 %b) {
define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ula_swap1:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %c, %a
%l1 = icmp ult i64 %b, %c
@@ -91,9 +85,8 @@ define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ula_swap2:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ult i64 %a, %c
%l1 = icmp ugt i64 %c, %b
@@ -104,9 +97,8 @@ define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ula_swap12:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %c, %a
%l1 = icmp ugt i64 %c, %b
@@ -119,9 +111,8 @@ define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
define i1 @ugo(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ugo:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a0, a1
-; CHECK-NEXT: sltu a0, a0, a2
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a0, a1
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %a, %c
%l1 = icmp ugt i64 %b, %c
@@ -132,9 +123,8 @@ define i1 @ugo(i64 %c, i64 %a, i64 %b) {
define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ugo_swap1:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a0, a1
-; CHECK-NEXT: sltu a0, a0, a2
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a0, a1
; CHECK-NEXT: ret
%l0 = icmp ult i64 %c, %a
%l1 = icmp ugt i64 %b, %c
@@ -145,9 +135,8 @@ define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ugo_swap2:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a0, a1
-; CHECK-NEXT: sltu a0, a0, a2
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a0, a1
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %a, %c
%l1 = icmp ult i64 %c, %b
@@ -158,9 +147,8 @@ define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ugo_swap12:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a0, a1
-; CHECK-NEXT: sltu a0, a0, a2
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: maxu a1, a1, a2
+; CHECK-NEXT: sltu a0, a0, a1
; CHECK-NEXT: ret
%l0 = icmp ult i64 %c, %a
%l1 = icmp ult i64 %c, %b
@@ -173,9 +161,8 @@ define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
define i1 @ugea(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: ugea:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: sltu a0, a2, a0
-; CHECK-NEXT: or a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a1, a0
; CHECK-NEXT: xori a0, a0, 1
; CHECK-NEXT: ret
%l0 = icmp uge i64 %a, %c
@@ -189,9 +176,8 @@ define i1 @ugea(i64 %c, i64 %a, i64 %b) {
define i1 @uga(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: uga:
; CHECK: # %bb.0:
-; CHECK-NEXT: sltu a1, a0, a1
-; CHECK-NEXT: sltu a0, a0, a2
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: minu a1, a1, a2
+; CHECK-NEXT: sltu a0, a0, a1
; CHECK-NEXT: ret
%l0 = icmp ugt i64 %a, %c
%l1 = icmp ugt i64 %b, %c
@@ -204,9 +190,8 @@ define i1 @uga(i64 %c, i64 %a, i64 %b) {
define i1 @sla(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: sla:
; CHECK: # %bb.0:
-; CHECK-NEXT: slt a1, a1, a0
-; CHECK-NEXT: slt a0, a2, a0
-; CHECK-NEXT: and a0, a1, a0
+; CHECK-NEXT: max a1, a1, a2
+; CHECK-NEXT: slt a0, a1, a0
; CHECK-NEXT: ret
%l0 = icmp slt i64 %a, %c
%l1 = icmp slt i64 %b, %c
@@ -214,6 +199,7 @@ define i1 @sla(i64 %c, i64 %a, i64 %b) {
ret i1 %res
}
+; Negative test
; Float check.
define i1 @flo(float %c, float %a, float %b) {
; CHECK-RV64I-LABEL: flo:
@@ -259,6 +245,7 @@ define i1 @flo(float %c, float %a, float %b) {
ret i1 %res
}
+; Negative test
; Double check.
define i1 @dlo(double %c, double %a, double %b) {
; CHECK-LABEL: dlo:
@@ -296,6 +283,7 @@ define i1 @dlo(double %c, double %a, double %b) {
ret i1 %res
}
+; Negative test
; More than one user
define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: multi_user:
@@ -313,6 +301,7 @@ define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
ret i1 %out
}
+; Negative test
; No same comparations
define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) {
; CHECK-LABEL: no_same_ops:
More information about the llvm-commits
mailing list