[llvm] [RISCV] Handle disjoint or in RISCVGatherScatterLowering (PR #77800)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 09:19:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

This patch adds support for the `disjoint` flag on `or` in the non-recursive
case (`matchStridedStart`). The recursive case (`matchStridedRecurrence`)
already handled `or` by checking that its operands had no common bits set; this
patch replaces that check with a check for the `disjoint` flag, since
InstCombine already computes it for us.

Co-authored-by: Philip Reames <preames@rivosinc.com>


---
Full diff: https://github.com/llvm/llvm-project/pull/77800.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (+10-1) 
- (modified) llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll (+2-5) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1129206800ad36..1dcb83a6078ed7 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
   if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Or &&
               BO->getOpcode() != Instruction::Shl &&
               BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
+  if (BO->getOpcode() == Instruction::Or &&
+      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
+    return std::make_pair(nullptr, nullptr);
+
   // Look for an operand that is splatted.
   unsigned OtherIndex = 0;
   Value *Splat = getSplatValue(BO->getOperand(1));
@@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   switch (BO->getOpcode()) {
   default:
     llvm_unreachable("Unexpected opcode");
+  case Instruction::Or:
+    // TODO: We'd be better off creating disjoint or here, but we don't yet
+    // have an IRBuilder API for that.
+    [[fallthrough]];
   case Instruction::Add:
     Start = Builder.CreateAdd(Start, Splat);
     break;
@@ -241,7 +250,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
     return false;
   case Instruction::Or:
     // We need to be able to treat Or as Add.
-    if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
       return false;
     break;
   case Instruction::Add:
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 838089baa46fc4..54e5d39e248544 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
 
 define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
 ; CHECK-LABEL: @straightline_offset_disjoint_or(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()

``````````

</details>


https://github.com/llvm/llvm-project/pull/77800


More information about the llvm-commits mailing list