[llvm] 0cf768e - [RISCV] Handle disjoint or in RISCVGatherScatterLowering (#77800)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 14 22:37:13 PST 2024


Author: Luke Lau
Date: 2024-01-15T13:37:09+07:00
New Revision: 0cf768e7f12dfb581fbae40a3b9b77f6c4533c29

URL: https://github.com/llvm/llvm-project/commit/0cf768e7f12dfb581fbae40a3b9b77f6c4533c29
DIFF: https://github.com/llvm/llvm-project/commit/0cf768e7f12dfb581fbae40a3b9b77f6c4533c29.diff

LOG: [RISCV] Handle disjoint or in RISCVGatherScatterLowering (#77800)

This patch adds support for the disjoint flag in the non-recursive case
and adds an additional check for it in the recursive case. Note that
haveNoCommonBitsSet should be equivalent to having the disjoint flag set,
so the haveNoCommonBitsSet check can be removed in a follow-up patch.

Co-authored-by: Philip Reames <preames at rivosinc.com>

---------

Co-authored-by: Philip Reames <preames at rivosinc.com>

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1129206800ad36..cd438e153068e5 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
   if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Or &&
               BO->getOpcode() != Instruction::Shl &&
               BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
+  if (BO->getOpcode() == Instruction::Or &&
+      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
+    return std::make_pair(nullptr, nullptr);
+
   // Look for an operand that is splatted.
   unsigned OtherIndex = 0;
   Value *Splat = getSplatValue(BO->getOperand(1));
@@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   switch (BO->getOpcode()) {
   default:
     llvm_unreachable("Unexpected opcode");
+  case Instruction::Or:
+    // TODO: We'd be better off creating disjoint or here, but we don't yet
+    // have an IRBuilder API for that.
+    [[fallthrough]];
   case Instruction::Add:
     Start = Builder.CreateAdd(Start, Splat);
     break;
@@ -241,7 +250,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
     return false;
   case Instruction::Or:
     // We need to be able to treat Or as Add.
-    if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+    if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL) &&
+        !cast<PossiblyDisjointInst>(BO)->isDisjoint())
       return false;
     break;
   case Instruction::Add:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 838089baa46fc4..54e5d39e248544 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
 
 define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
 ; CHECK-LABEL: @straightline_offset_disjoint_or(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()


        


More information about the llvm-commits mailing list