[llvm] [RISCV] Add a MIR pass to reassociate shXadd, add, and slli to form more shXadd. (PR #87544)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 09:44:04 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/87544

>From eec429928f120ed6db17af0c134aa403befaeb8d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 3 Apr 2024 11:58:18 -0700
Subject: [PATCH 1/5] [RISCV] Add tests for reassociating to form more shXadd
 instructions.

---
 llvm/test/CodeGen/RISCV/rv64zba.ll | 361 +++++++++++++++++++++++++++++
 1 file changed, 361 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index c81c6aeaab890c..7e32253c8653f1 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -2036,3 +2036,364 @@ define i64 @pack_i64_disjoint_2(i32 signext %a, i64 %b) nounwind {
   %or = or disjoint i64 %b, %zexta
   ret i64 %or
 }
+
+define i8 @array_index_sh1_sh0(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh1_sh0:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 1
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lbu a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh1_sh0:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh1add a0, a1, a0
+; RV64ZBA-NEXT:    add a0, a0, a2
+; RV64ZBA-NEXT:    lbu a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [2 x i8], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i8, ptr %a, align 1
+  ret i8 %b
+}
+
+define i16 @array_index_sh1_sh1(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh1_sh1:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 1
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh1_sh1:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
+; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [2 x i16], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i16, ptr %a, align 2
+  ret i16 %b
+}
+
+define i32 @array_index_sh1_sh2(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh1_sh2:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 3
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh1_sh2:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [2 x i32], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i32, ptr %a, align 4
+  ret i32 %b
+}
+
+define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh1_sh3:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 4
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh1_sh3:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 4
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i64, ptr %a, align 8
+  ret i64 %b
+}
+
+define i8 @array_index_sh2_sh0(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh2_sh0:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lbu a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh2_sh0:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
+; RV64ZBA-NEXT:    add a0, a0, a2
+; RV64ZBA-NEXT:    lbu a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [4 x i8], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i8, ptr %a, align 1
+  ret i8 %b
+}
+
+define i16 @array_index_sh2_sh1(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh2_sh1:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 3
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 1
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh2_sh1:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [4 x i16], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i16, ptr %a, align 2
+  ret i16 %b
+}
+
+define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh2_sh2:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 4
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh2_sh2:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 4
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i32, ptr %a, align 4
+  ret i32 %b
+}
+
+define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh2_sh3:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 5
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh2_sh3:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 5
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i64, ptr %a, align 8
+  ret i64 %b
+}
+
+define i8 @array_index_sh3_sh0(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh3_sh0:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lbu a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh3_sh0:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    add a0, a0, a2
+; RV64ZBA-NEXT:    lbu a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [8 x i8], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i8, ptr %a, align 1
+  ret i8 %b
+}
+
+define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh3_sh1:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 4
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 1
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh3_sh1:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 4
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i16, ptr %a, align 2
+  ret i16 %b
+}
+
+define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh3_sh2:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 5
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh3_sh2:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 5
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i32, ptr %a, align 4
+  ret i32 %b
+}
+
+define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh3_sh3:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 6
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh3_sh3:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 6
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i64, ptr %a, align 8
+  ret i64 %b
+}
+
+; Similar to above, but with a lshr on one of the indices. This requires
+; special handling during isel to form a shift pair.
+define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_lshr_sh3_sh3:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    srli a1, a1, 58
+; RV64I-NEXT:    slli a1, a1, 6
+; RV64I-NEXT:    slli a2, a2, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    srli a1, a1, 58
+; RV64ZBA-NEXT:    slli a1, a1, 6
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %shr = lshr i64 %idx1, 58
+  %a = getelementptr inbounds [8 x i64], ptr %p, i64 %shr, i64 %idx2
+  %b = load i64, ptr %a, align 8
+  ret i64 %b
+}
+
+define i8 @array_index_sh4_sh0(ptr %p, i64 %idx1, i64 %idx2) {
+; CHECK-LABEL: array_index_sh4_sh0:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a1, a1, 4
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    lbu a0, 0(a0)
+; CHECK-NEXT:    ret
+  %a = getelementptr inbounds [16 x i8], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i8, ptr %a, align 1
+  ret i8 %b
+}
+
+define i16 @array_index_sh4_sh1(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh4_sh1:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 5
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 1
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh4_sh1:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 5
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [16 x i16], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i16, ptr %a, align 2
+  ret i16 %b
+}
+
+define i32 @array_index_sh4_sh2(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh4_sh2:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 6
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh4_sh2:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 6
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [16 x i32], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i32, ptr %a, align 4
+  ret i32 %b
+}
+
+define i64 @array_index_sh4_sh3(ptr %p, i64 %idx1, i64 %idx2) {
+; RV64I-LABEL: array_index_sh4_sh3:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 7
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    slli a2, a2, 3
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: array_index_sh4_sh3:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 7
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %a = getelementptr inbounds [16 x i64], ptr %p, i64 %idx1, i64 %idx2
+  %b = load i64, ptr %a, align 8
+  ret i64 %b
+}

>From 695bd0a7dbd31f7b981298cb45490a33a17738d3 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 2 Apr 2024 13:01:14 -0700
Subject: [PATCH 2/5] [RISCV] Add a MIR pass to reassociate shXadd, add, and
 slli to form more shXadd.

This reassociates patterns like (sh3add Z, (add X, (slli Y, 6)))
into (sh3add (sh3add Y, Z), X).

This improves a pattern that occurs in 531.deepsjeng_r, reducing
the dynamic instruction count by 0.5%.

This may be possible to improve in SelectionDAG, but given the special
cases around shXadd formation, it's not obvious it can be done in a
robust way without adding multiple special cases.

I've used a GEP with 2 indices because that most closely resembles
the motivating case. Most of the test cases are the simplest GEP case.
One test has a logical right shift on an index which is closer to
the deepsjeng code. This requires special handling in isel to reverse
a DAGCombiner canonicalization that turns a pair of shifts into
(srl (and X, C1), C2).

See also #85734 which had a hacky version of a similar optimization.
---
 llvm/lib/Target/RISCV/CMakeLists.txt         |   1 +
 llvm/lib/Target/RISCV/RISCV.h                |   3 +
 llvm/lib/Target/RISCV/RISCVOptZba.cpp        | 150 +++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVTargetMachine.cpp |   3 +
 llvm/test/CodeGen/RISCV/O3-pipeline.ll       |   1 +
 llvm/test/CodeGen/RISCV/rv64zba.ll           |  40 ++---
 6 files changed, 174 insertions(+), 24 deletions(-)
 create mode 100644 llvm/lib/Target/RISCV/RISCVOptZba.cpp

diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index 8715403f3839a6..8708d3b1ab70f0 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -46,6 +46,7 @@ add_llvm_target(RISCVCodeGen
   RISCVMachineFunctionInfo.cpp
   RISCVMergeBaseOffset.cpp
   RISCVOptWInstrs.cpp
+  RISCVOptZba.cpp
   RISCVPostRAExpandPseudoInsts.cpp
   RISCVRedundantCopyElimination.cpp
   RISCVMoveMerger.cpp
diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h
index 7af543f018ccbd..a3f9949425fcc4 100644
--- a/llvm/lib/Target/RISCV/RISCV.h
+++ b/llvm/lib/Target/RISCV/RISCV.h
@@ -46,6 +46,9 @@ void initializeRISCVFoldMasksPass(PassRegistry &);
 FunctionPass *createRISCVOptWInstrsPass();
 void initializeRISCVOptWInstrsPass(PassRegistry &);
 
+FunctionPass *createRISCVOptZbaPass();
+void initializeRISCVOptZbaPass(PassRegistry &);
+
 FunctionPass *createRISCVMergeBaseOffsetOptPass();
 void initializeRISCVMergeBaseOffsetOptPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/RISCV/RISCVOptZba.cpp b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
new file mode 100644
index 00000000000000..dc19f921783c95
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
@@ -0,0 +1,150 @@
+//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This pass attempts to reassociate expressions like
+//   (sh3add Z, (add X, (slli Y, 6)))
+// To
+//   (sh3add (sh3add Y, Z), X)
+//
+// This reduces number of instructions needing by spreading the shift amount
+// across to 2 sh3adds. This can be generalized to sh1add/sh2add and other
+// shift amounts.
+//
+// TODO: We can also support slli.uw by using shXadd.uw or the inner shXadd.
+
+#include "RISCV.h"
+#include "RISCVSubtarget.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-opt-zba"
+#define RISCV_OPT_ZBA_NAME "RISC-V Optimize Zba"
+
+namespace {
+
+class RISCVOptZba : public MachineFunctionPass {
+public:
+  static char ID;
+
+  RISCVOptZba() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override { return RISCV_OPT_ZBA_NAME; }
+};
+
+} // end anonymous namespace
+
+char RISCVOptZba::ID = 0;
+INITIALIZE_PASS(RISCVOptZba, DEBUG_TYPE, RISCV_OPT_ZBA_NAME, false, false)
+
+FunctionPass *llvm::createRISCVOptZbaPass() { return new RISCVOptZba(); }
+
+static bool findShift(Register Reg, const MachineBasicBlock &MBB,
+                      MachineInstr *&Shift, const MachineRegisterInfo &MRI) {
+  if (!Reg.isVirtual())
+    return false;
+
+  Shift = MRI.getVRegDef(Reg);
+  if (!Shift || Shift->getOpcode() != RISCV::SLLI ||
+      Shift->getParent() != &MBB || !MRI.hasOneNonDBGUse(Reg))
+    return false;
+
+  return true;
+}
+
+bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
+  if (skipFunction(MF.getFunction()))
+    return false;
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
+  const RISCVInstrInfo &TII = *ST.getInstrInfo();
+
+  if (!ST.hasStdExtZba())
+    return false;
+
+  bool MadeChange = false; // Set only once a rewrite is performed.
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
+      unsigned OuterShiftAmt;
+      switch (MI.getOpcode()) {
+      default:
+        continue;
+      case RISCV::SH1ADD: OuterShiftAmt = 1; break;
+      case RISCV::SH2ADD: OuterShiftAmt = 2; break;
+      case RISCV::SH3ADD: OuterShiftAmt = 3; break;
+      }
+
+      // Second operand must be virtual.
+      Register UnshiftedReg = MI.getOperand(2).getReg();
+      if (!UnshiftedReg.isVirtual())
+        continue;
+
+      MachineInstr *Add = MRI.getVRegDef(UnshiftedReg);
+      if (!Add || Add->getOpcode() != RISCV::ADD || Add->getParent() != &MBB ||
+          !MRI.hasOneNonDBGUse(UnshiftedReg))
+        continue;
+
+      Register AddReg0 = Add->getOperand(1).getReg();
+      Register AddReg1 = Add->getOperand(2).getReg();
+
+      MachineInstr *InnerShift;
+      Register X;
+      if (findShift(AddReg0, MBB, InnerShift, MRI))
+        X = AddReg1;
+      else if (findShift(AddReg1, MBB, InnerShift, MRI))
+        X = AddReg0;
+      else
+        continue;
+
+      unsigned InnerShiftAmt = InnerShift->getOperand(2).getImm();
+
+      // The inner shift amount must be at least as large as the outer shift
+      // amount.
+      if (OuterShiftAmt > InnerShiftAmt)
+        continue;
+
+      unsigned InnerOpc;
+      switch (InnerShiftAmt - OuterShiftAmt) {
+      default:
+        continue;
+      case 0: InnerOpc = RISCV::ADD;    break;
+      case 1: InnerOpc = RISCV::SH1ADD; break;
+      case 2: InnerOpc = RISCV::SH2ADD; break;
+      case 3: InnerOpc = RISCV::SH3ADD; break;
+      }
+
+      Register Y = InnerShift->getOperand(1).getReg();
+      Register Z = MI.getOperand(1).getReg();
+
+      Register NewReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(InnerOpc), NewReg)
+          .addReg(Y)
+          .addReg(Z);
+      BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(MI.getOpcode()),
+              MI.getOperand(0).getReg())
+          .addReg(NewReg)
+          .addReg(X);
+
+      MI.eraseFromParent();
+      Add->eraseFromParent();
+      InnerShift->eraseFromParent();
+      MadeChange = true;
+    }
+  }
+
+  return MadeChange;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index ae1a6f179a49e3..7c6ea2b15be0cb 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -117,6 +117,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
   initializeRISCVPostRAExpandPseudoPass(*PR);
   initializeRISCVMergeBaseOffsetOptPass(*PR);
   initializeRISCVOptWInstrsPass(*PR);
+  initializeRISCVOptZbaPass(*PR);
   initializeRISCVPreRAExpandPseudoPass(*PR);
   initializeRISCVExpandPseudoPass(*PR);
   initializeRISCVFoldMasksPass(*PR);
@@ -531,6 +532,8 @@ void RISCVPassConfig::addMachineSSAOptimization() {
   if (EnableMachineCombiner)
     addPass(&MachineCombinerID);
 
+  addPass(createRISCVOptZbaPass());
+
   if (TM->getTargetTriple().isRISCV64()) {
     addPass(createRISCVOptWInstrsPass());
   }
diff --git a/llvm/test/CodeGen/RISCV/O3-pipeline.ll b/llvm/test/CodeGen/RISCV/O3-pipeline.ll
index 90472f246918f3..8fc610bb07b7f6 100644
--- a/llvm/test/CodeGen/RISCV/O3-pipeline.ll
+++ b/llvm/test/CodeGen/RISCV/O3-pipeline.ll
@@ -112,6 +112,7 @@
 ; CHECK-NEXT:       Machine Trace Metrics
 ; CHECK-NEXT:       Lazy Machine Block Frequency Analysis
 ; CHECK-NEXT:       Machine InstCombiner
+; CHECK-NEXT:       RISC-V Optimize Zba
 ; RV64-NEXT:        RISC-V Optimize W Instructions
 ; CHECK-NEXT:       RISC-V Pre-RA pseudo instruction expansion pass
 ; CHECK-NEXT:       RISC-V Merge Base Offset
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 7e32253c8653f1..067addc819f7e6 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
 ;
 ; RV64ZBA-LABEL: sh6_sh3_add2:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a1, a0
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ret
 entry:
   %shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh1_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh1add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh1:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh1add a0, a1, a0
 ; RV64ZBA-NEXT:    lh a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
 ; RV64ZBA-NEXT:    srli a1, a1, 58
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %shr = lshr i64 %idx1, 58

>From a8f22b1a0f0105698caa85a754d5f366a75d1a52 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 3 Apr 2024 12:17:50 -0700
Subject: [PATCH 3/5] fixup! update header comments.

---
 llvm/lib/Target/RISCV/RISCVOptZba.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVOptZba.cpp b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
index dc19f921783c95..73505f48749b82 100644
--- a/llvm/lib/Target/RISCV/RISCVOptZba.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
@@ -1,4 +1,4 @@
-//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
+//===- RISCVOptZba.cpp - MI Zba instruction optimizations -----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -16,6 +16,8 @@
 // shift amounts.
 //
 // TODO: We can also support slli.uw by using shXadd.uw or the inner shXadd.
+//
+//===---------------------------------------------------------------------===//
 
 #include "RISCV.h"
 #include "RISCVSubtarget.h"

>From bf47a44b9bfd1d36bcbc36a4a312223d1439762c Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 3 Apr 2024 12:23:27 -0700
Subject: [PATCH 4/5] fixup! clang-format

---
 llvm/lib/Target/RISCV/RISCVOptZba.cpp | 28 ++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVOptZba.cpp b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
index 73505f48749b82..5ceb2e0269e030 100644
--- a/llvm/lib/Target/RISCV/RISCVOptZba.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
@@ -85,9 +85,15 @@ bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
       switch (MI.getOpcode()) {
       default:
         continue;
-      case RISCV::SH1ADD: OuterShiftAmt = 1; break;
-      case RISCV::SH2ADD: OuterShiftAmt = 2; break;
-      case RISCV::SH3ADD: OuterShiftAmt = 3; break;
+      case RISCV::SH1ADD:
+        OuterShiftAmt = 1;
+        break;
+      case RISCV::SH2ADD:
+        OuterShiftAmt = 2;
+        break;
+      case RISCV::SH3ADD:
+        OuterShiftAmt = 3;
+        break;
       }
 
       // Second operand must be virtual.
@@ -123,10 +129,18 @@ bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
       switch (InnerShiftAmt - OuterShiftAmt) {
       default:
         continue;
-      case 0: InnerOpc = RISCV::ADD;    break;
-      case 1: InnerOpc = RISCV::SH1ADD; break;
-      case 2: InnerOpc = RISCV::SH2ADD; break;
-      case 3: InnerOpc = RISCV::SH3ADD; break;
+      case 0:
+        InnerOpc = RISCV::ADD;
+        break;
+      case 1:
+        InnerOpc = RISCV::SH1ADD;
+        break;
+      case 2:
+        InnerOpc = RISCV::SH2ADD;
+        break;
+      case 3:
+        InnerOpc = RISCV::SH3ADD;
+        break;
       }
 
       Register Y = InnerShift->getOperand(1).getReg();

>From cb16e3ea0d683c610a290a55b1b8a41930c7910a Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 5 Apr 2024 09:42:30 -0700
Subject: [PATCH 5/5] fixup! address review comments.

---
 llvm/lib/Target/RISCV/RISCVOptZba.cpp | 34 +++++++++++++++------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVOptZba.cpp b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
index 5ceb2e0269e030..b61e79140f9bc5 100644
--- a/llvm/lib/Target/RISCV/RISCVOptZba.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptZba.cpp
@@ -6,16 +6,20 @@
 //
 //===---------------------------------------------------------------------===//
 //
-// This pass attempts to reassociate expressions like
-//   (sh3add Z, (add X, (slli Y, 6)))
+// This pass reassociates expressions like
+//   (sh3add Z, (add X, (slli Y, 5)))
 // To
-//   (sh3add (sh3add Y, Z), X)
+//   (sh3add (sh2add Y, Z), X)
 //
-// This reduces number of instructions needing by spreading the shift amount
-// across to 2 sh3adds. This can be generalized to sh1add/sh2add and other
-// shift amounts.
+// If the shift amount is small enough. The outer shXadd keeps its original
+// opcode. The inner shXadd shift amount is the difference between the slli
+// shift amount and the outer shXadd shift amount.
 //
-// TODO: We can also support slli.uw by using shXadd.uw or the inner shXadd.
+// This pattern can appear when indexing a two dimensional array, but it is not
+// limited to that.
+//
+// TODO: We can also support slli.uw by using shXadd.uw for the inner shXadd.
+// TODO: This can be generalized to deeper expressions.
 //
 //===---------------------------------------------------------------------===//
 
@@ -53,17 +57,17 @@ INITIALIZE_PASS(RISCVOptZba, DEBUG_TYPE, RISCV_OPT_ZBA_NAME, false, false)
 
 FunctionPass *llvm::createRISCVOptZbaPass() { return new RISCVOptZba(); }
 
-static bool findShift(Register Reg, const MachineBasicBlock &MBB,
-                      MachineInstr *&Shift, const MachineRegisterInfo &MRI) {
+static MachineInstr *findShift(Register Reg, const MachineBasicBlock &MBB,
+                               MachineRegisterInfo &MRI) {
   if (!Reg.isVirtual())
-    return false;
+    return nullptr;
 
-  Shift = MRI.getVRegDef(Reg);
+  MachineInstr *Shift = MRI.getVRegDef(Reg);
   if (!Shift || Shift->getOpcode() != RISCV::SLLI ||
       Shift->getParent() != &MBB || !MRI.hasOneNonDBGUse(Reg))
-    return false;
+    return nullptr;
 
-  return true;
+  return Shift;
 }
 
 bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
@@ -111,9 +115,9 @@ bool RISCVOptZba::runOnMachineFunction(MachineFunction &MF) {
 
       MachineInstr *InnerShift;
       Register X;
-      if (findShift(AddReg0, MBB, InnerShift, MRI))
+      if ((InnerShift = findShift(AddReg0, MBB, MRI)))
         X = AddReg1;
-      else if (findShift(AddReg1, MBB, InnerShift, MRI))
+      else if ((InnerShift = findShift(AddReg1, MBB, MRI)))
         X = AddReg0;
       else
         continue;



More information about the llvm-commits mailing list