[llvm] [RISCV] Handle disjoint or in RISCVGatherScatterLowering (PR #77800)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 14 20:57:59 PST 2024
https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/77800
From 9199ee0ab8c7a7bf599c96092a6045e46fb4c557 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 12 Jan 2024 00:11:49 +0700
Subject: [PATCH 1/2] [RISCV] Handle disjoint or in RISCVGatherScatterLowering
This patch adds support for the disjoint flag in the non-recursive case
(matchStridedStart). For the recursive case (matchStridedRecurrence) we were
already handling disjoint ors by checking that the operands had no common bits
set. This patch replaces that check with a check for the disjoint flag
instead, since instcombine will already compute it for us.
Co-authored-by: Philip Reames <preames at rivosinc.com>
---
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp | 11 ++++++++++-
llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll | 7 ++-----
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1129206800ad36..1dcb83a6078ed7 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
// multipled.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || (BO->getOpcode() != Instruction::Add &&
+ BO->getOpcode() != Instruction::Or &&
BO->getOpcode() != Instruction::Shl &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);
+ if (BO->getOpcode() == Instruction::Or &&
+ !cast<PossiblyDisjointInst>(BO)->isDisjoint())
+ return std::make_pair(nullptr, nullptr);
+
// Look for an operand that is splatted.
unsigned OtherIndex = 0;
Value *Splat = getSplatValue(BO->getOperand(1));
@@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
switch (BO->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
+ case Instruction::Or:
+ // TODO: We'd be better off creating disjoint or here, but we don't yet
+ // have an IRBuilder API for that.
+ [[fallthrough]];
case Instruction::Add:
Start = Builder.CreateAdd(Start, Splat);
break;
@@ -241,7 +250,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
return false;
case Instruction::Or:
// We need to be able to treat Or as Add.
- if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+ if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
break;
case Instruction::Add:
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 838089baa46fc4..54e5d39e248544 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
; CHECK-LABEL: @straightline_offset_disjoint_or(
-; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT: [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
-; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
From bb9a5872ca8e80a9ac8d49b75f8f71a3647f26d2 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 15 Jan 2024 11:56:54 +0700
Subject: [PATCH 2/2] Keep haveNoCommonBitsSet check
---
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1dcb83a6078ed7..cd438e153068e5 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -250,7 +250,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
return false;
case Instruction::Or:
// We need to be able to treat Or as Add.
- if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
+ if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL) &&
+ !cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
break;
case Instruction::Add:
More information about the llvm-commits mailing list