[llvm] 9effe38 - [AArch64][GlobalISel] Fold G_XOR into TB(N)Z bit calculation

Jessica Paquette via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 3 15:25:23 PST 2020


Author: Jessica Paquette
Date: 2020-02-03T15:22:24-08:00
New Revision: 9effe38b225f3dfd72d6f1800f2ea47175b5bf95

URL: https://github.com/llvm/llvm-project/commit/9effe38b225f3dfd72d6f1800f2ea47175b5bf95
DIFF: https://github.com/llvm/llvm-project/commit/9effe38b225f3dfd72d6f1800f2ea47175b5bf95.diff

LOG: [AArch64][GlobalISel] Fold G_XOR into TB(N)Z bit calculation

This ports the existing case for G_XOR from `getTestBitOperand` in
AArch64ISelLowering into GlobalISel.

The idea is to flip between TBZ and TBNZ while walking through G_XORs.

Let's say we have

```
tbz (xor x, c), b
```

Let's say the `b`-th bit in `c` is 1. Then

- If the `b`-th bit in `x` is 1, then the `b`-th bit in `(xor x, c)` is 0.
- If the `b`-th bit in `x` is 0, then the `b`-th bit in `(xor x, c)` is 1.

So

```
tbz (xor x, c), b == tbnz x, b
```

Let's say the `b`-th bit in `c` is 0. Then

- If the `b`-th bit in `x` is 1, then the `b`-th bit in `(xor x, c)` is 1.
- If the `b`-th bit in `x` is 0, then the `b`-th bit in `(xor x, c)` is 0.

So

```
tbz (xor x, c), b == tbz x, b
```
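
To sanity-check both cases above, here is a small standalone C++ sketch (not part of the patch). `testBit` is a helper name invented for this example, and the brute-force ranges are arbitrary; it simply models the bit test a TB(N)Z performs and asserts the flip rule for every small `x`, `c`, and bit position.

```cpp
// Standalone sketch: verify that testing bit b of (x ^ c) for zero is the
// same as testing bit b of x for one when c has bit b set, and the same as
// testing bit b of x for zero when c has bit b clear.
#include <cassert>
#include <cstdint>

static bool testBit(uint64_t Val, unsigned Bit) { return (Val >> Bit) & 1; }

int main() {
  for (uint64_t X = 0; X < 256; ++X) {
    for (uint64_t C = 0; C < 256; ++C) {
      for (unsigned Bit = 0; Bit < 8; ++Bit) {
        bool TbzOnXor = !testBit(X ^ C, Bit); // tbz (xor x, c), b
        if (testBit(C, Bit))
          assert(TbzOnXor == testBit(X, Bit));  // == tbnz x, b
        else
          assert(TbzOnXor == !testBit(X, Bit)); // == tbz x, b
      }
    }
  }
  return 0;
}
```

Compiling and running this with assertions enabled exercises both directions of the TBZ/TBNZ flip.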

Differential Revision: https://reviews.llvm.org/D73929

Added: 
    llvm/test/CodeGen/AArch64/GlobalISel/opt-fold-xor-tbz-tbnz.mir

Modified: 
    llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
index e7d90bf1af44..f6f710826f59 100644
--- a/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
@@ -991,7 +991,7 @@ static void changeFCMPPredToAArch64CC(CmpInst::Predicate P,
 }
 
 /// Return a register which can be used as a bit to test in a TB(N)Z.
-static Register getTestBitReg(Register Reg, uint64_t &Bit,
+static Register getTestBitReg(Register Reg, uint64_t &Bit, bool &Invert,
                               MachineRegisterInfo &MRI) {
   assert(Reg.isValid() && "Expected valid register!");
   while (MachineInstr *MI = getDefIgnoringCopies(Reg, MRI)) {
@@ -1018,7 +1018,8 @@ static Register getTestBitReg(Register Reg, uint64_t &Bit,
     switch (Opc) {
     default:
       break;
-    case TargetOpcode::G_AND: {
+    case TargetOpcode::G_AND:
+    case TargetOpcode::G_XOR: {
       TestReg = MI->getOperand(1).getReg();
       Register ConstantReg = MI->getOperand(2).getReg();
       auto VRegAndVal = getConstantVRegValWithLookThrough(ConstantReg, MRI);
@@ -1066,6 +1067,19 @@ static Register getTestBitReg(Register Reg, uint64_t &Bit,
         Bit = Bit - *C;
       }
       break;
+    case TargetOpcode::G_XOR:
+      // We can walk through a G_XOR by inverting whether we use tbz/tbnz when
+      // appropriate.
+      //
+      // e.g. If x' = xor x, c, and the b-th bit is set in c then
+      //
+      // tbz x', b -> tbnz x, b
+      //
+      // Because x' only has the b-th bit set if x does not.
+      if ((*C >> Bit) & 1)
+        Invert = !Invert;
+      NextReg = TestReg;
+      break;
     }
 
     // Check if we found anything worth folding.
@@ -1124,20 +1138,21 @@ bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
   // Try to optimize the TB(N)Z.
   uint64_t Bit = Log2_64(static_cast<uint64_t>(MaybeBit->Value));
   Register TestReg = AndInst->getOperand(1).getReg();
-  TestReg = getTestBitReg(TestReg, Bit, MRI);
+  bool Invert = Pred == CmpInst::Predicate::ICMP_NE;
+  TestReg = getTestBitReg(TestReg, Bit, Invert, MRI);
 
   // Choose the correct TB(N)Z opcode to use.
   unsigned Opc = 0;
   if (Bit < 32) {
     // When the bit is less than 32, we have to use a TBZW even if we're on a 64
     // bit register.
-    Opc = Pred == CmpInst::Predicate::ICMP_EQ ? AArch64::TBZW : AArch64::TBNZW;
+    Opc = Invert ? AArch64::TBNZW : AArch64::TBZW;
     TestReg = narrowExtendRegIfNeeded(TestReg, MIB);
   } else {
     // Same idea for when Bit >= 32. We don't have to narrow here, because if
     // Bit > 32, then the G_CONSTANT must be outside the range of valid 32-bit
     // values. So, we must have a s64.
-    Opc = Pred == CmpInst::Predicate::ICMP_EQ ? AArch64::TBZX : AArch64::TBNZX;
+    Opc = Invert ? AArch64::TBNZX : AArch64::TBZX;
   }
 
   // Construct the branch.

diff  --git a/llvm/test/CodeGen/AArch64/GlobalISel/opt-fold-xor-tbz-tbnz.mir b/llvm/test/CodeGen/AArch64/GlobalISel/opt-fold-xor-tbz-tbnz.mir
new file mode 100644
index 000000000000..8e19ba41b2c0
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/opt-fold-xor-tbz-tbnz.mir
@@ -0,0 +1,188 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple aarch64-unknown-unknown -run-pass=instruction-select -verify-machineinstrs %s -o - | FileCheck %s
+...
+---
+name:            flip_eq
+alignment:       4
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: flip_eq
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %copy:gpr64all = COPY $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:gpr32all = COPY %copy.sub_32
+  ; CHECK:   [[COPY1:%[0-9]+]]:gpr32 = COPY [[COPY]]
+  ; CHECK:   TBNZW [[COPY1]], 3, %bb.1
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    successors: %bb.0, %bb.1
+    liveins: $x0
+    %copy:gpr(s64) = COPY $x0
+
+    ; Check bit 3.
+    %bit:gpr(s64) = G_CONSTANT i64 8
+    %zero:gpr(s64) = G_CONSTANT i64 0
+
+    ; 8 has the third bit set.
+    %fold_cst:gpr(s64) = G_CONSTANT i64 8
+
+    ; This only has the third bit set if %copy does not. So, to walk through
+    ; this, we want to use a TBNZW on %copy.
+    %fold_me:gpr(s64) = G_XOR %copy, %fold_cst
+
+    %and:gpr(s64) = G_AND %fold_me, %bit
+    %cmp:gpr(s32) = G_ICMP intpred(eq), %and(s64), %zero
+    %cmp_trunc:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cmp_trunc(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR
+...
+---
+name:            flip_ne
+alignment:       4
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: flip_ne
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %copy:gpr64all = COPY $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:gpr32all = COPY %copy.sub_32
+  ; CHECK:   [[COPY1:%[0-9]+]]:gpr32 = COPY [[COPY]]
+  ; CHECK:   TBZW [[COPY1]], 3, %bb.1
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    successors: %bb.0, %bb.1
+    liveins: $x0
+
+    ; Same as eq case, but we should get a TBZW instead.
+
+    %copy:gpr(s64) = COPY $x0
+    %bit:gpr(s64) = G_CONSTANT i64 8
+    %zero:gpr(s64) = G_CONSTANT i64 0
+    %fold_cst:gpr(s64) = G_CONSTANT i64 8
+    %fold_me:gpr(s64) = G_XOR %copy, %fold_cst
+    %and:gpr(s64) = G_AND %fold_me, %bit
+    %cmp:gpr(s32) = G_ICMP intpred(ne), %and(s64), %zero
+    %cmp_trunc:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cmp_trunc(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR
+...
+---
+name:            dont_flip_eq
+alignment:       4
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: dont_flip_eq
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %copy:gpr64all = COPY $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:gpr32all = COPY %copy.sub_32
+  ; CHECK:   [[COPY1:%[0-9]+]]:gpr32 = COPY [[COPY]]
+  ; CHECK:   TBZW [[COPY1]], 3, %bb.1
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    successors: %bb.0, %bb.1
+    liveins: $x0
+    %copy:gpr(s64) = COPY $x0
+
+    ; Check bit 3.
+    %bit:gpr(s64) = G_CONSTANT i64 8
+    %zero:gpr(s64) = G_CONSTANT i64 0
+
+    ; 7 does not have the third bit set.
+    %fold_cst:gpr(s64) = G_CONSTANT i64 7
+
+    ; This only has the third bit set if %copy does. So, to walk through this,
+    ; we should have a TBZW on %copy.
+    %fold_me:gpr(s64) = G_XOR %fold_cst, %copy
+
+    %and:gpr(s64) = G_AND %fold_me, %bit
+    %cmp:gpr(s32) = G_ICMP intpred(eq), %and(s64), %zero
+    %cmp_trunc:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cmp_trunc(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR
+...
+---
+name:            dont_flip_ne
+alignment:       4
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: dont_flip_ne
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %copy:gpr64all = COPY $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:gpr32all = COPY %copy.sub_32
+  ; CHECK:   [[COPY1:%[0-9]+]]:gpr32 = COPY [[COPY]]
+  ; CHECK:   TBNZW [[COPY1]], 3, %bb.1
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    successors: %bb.0, %bb.1
+    liveins: $x0
+
+    ; Same as eq case, but we should get a TBNZW instead.
+
+    %copy:gpr(s64) = COPY $x0
+    %bit:gpr(s64) = G_CONSTANT i64 8
+    %zero:gpr(s64) = G_CONSTANT i64 0
+    %fold_cst:gpr(s64) = G_CONSTANT i64 7
+    %fold_me:gpr(s64) = G_XOR %fold_cst, %copy
+    %and:gpr(s64) = G_AND %fold_me, %bit
+    %cmp:gpr(s32) = G_ICMP intpred(ne), %and(s64), %zero
+    %cmp_trunc:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cmp_trunc(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR
+...
+---
+name:            xor_chain
+alignment:       4
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: xor_chain
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %copy:gpr64all = COPY $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:gpr32all = COPY %copy.sub_32
+  ; CHECK:   [[COPY1:%[0-9]+]]:gpr32 = COPY [[COPY]]
+  ; CHECK:   TBZW [[COPY1]], 3, %bb.1
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    successors: %bb.0, %bb.1
+    liveins: $x0
+    %copy:gpr(s64) = COPY $x0
+    %bit:gpr(s64) = G_CONSTANT i64 8
+    %zero:gpr(s64) = G_CONSTANT i64 0
+    %fold_cst:gpr(s64) = G_CONSTANT i64 8
+
+    ; The G_XORs cancel each other out, so we should get a TBZW.
+    %xor1:gpr(s64) = G_XOR %copy, %fold_cst
+    %xor2:gpr(s64) = G_XOR %xor1, %fold_cst
+
+    %and:gpr(s64) = G_AND %xor2, %bit
+    %cmp:gpr(s32) = G_ICMP intpred(eq), %and(s64), %zero
+    %cmp_trunc:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cmp_trunc(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR
