[llvm] 82c5d35 - [VPlan] Add commutative binary OR matcher, use in transform. (#92539)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon May 20 05:03:52 PDT 2024


Author: Florian Hahn
Date: 2024-05-20T13:03:48+01:00
New Revision: 82c5d350d200ccc5365d40eac187b9ec967af727

URL: https://github.com/llvm/llvm-project/commit/82c5d350d200ccc5365d40eac187b9ec967af727
DIFF: https://github.com/llvm/llvm-project/commit/82c5d350d200ccc5365d40eac187b9ec967af727.diff

LOG: [VPlan] Add commutative binary OR matcher, use in transform. (#92539)

Split off from https://github.com/llvm/llvm-project/pull/89386, this
extends the binary recipe matcher to support matching commutative
operations. It is used to implement a new m_c_BinaryOr matcher, which is
used in simplifyRecipe.

PR: https://github.com/llvm/llvm-project/pull/92539
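
For readers unfamiliar with the pattern-match infrastructure, below is a
minimal standalone sketch (using hypothetical stand-in types, not the actual
VPlan recipe classes) of the commutative retry this patch adds: when
Commutative is true, the matcher falls back to matching the operands in
swapped order if the in-order match fails.

#include <cassert>

// Hypothetical stand-in for a VPlan value; illustration only.
struct Value { int Id; };

// Matches one specific value.
struct SpecificValue {
  const Value *Expected;
  bool match(const Value *V) const { return V == Expected; }
};

// Sketch of the commutative binary match: mirrors the logic added to
// BinaryRecipe_match, retrying with swapped operands when Commutative is true.
template <typename Op0_t, typename Op1_t, bool Commutative>
struct BinaryMatchSketch {
  Op0_t Op0;
  Op1_t Op1;
  bool match(const Value *LHS, const Value *RHS) const {
    if (Op0.match(LHS) && Op1.match(RHS))
      return true;
    return Commutative && Op0.match(RHS) && Op1.match(LHS);
  }
};

int main() {
  Value A{0}, B{1};
  BinaryMatchSketch<SpecificValue, SpecificValue, /*Commutative=*/false>
      NonComm{{&A}, {&B}};
  BinaryMatchSketch<SpecificValue, SpecificValue, /*Commutative=*/true>
      Comm{{&A}, {&B}};

  assert(NonComm.match(&A, &B));  // in-order match succeeds
  assert(!NonComm.match(&B, &A)); // swapped operands rejected
  assert(Comm.match(&B, &A));     // commutative matcher accepts swapped order
  return 0;
}

This is the same behavior the real m_c_BinaryOr relies on: the OR of the two
edge masks in simplifyRecipe can appear with its operands in either order,
and the commutative matcher recognizes both.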

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
    llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
    llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
    llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 56cbaa4201297..0587468807435 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -157,7 +157,7 @@ using AllUnaryRecipe_match =
     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
                       VPWidenCastRecipe, VPInstruction>;
 
-template <typename Op0_t, typename Op1_t, unsigned Opcode,
+template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
           typename... RecipeTys>
 struct BinaryRecipe_match {
   Op0_t Op0;
@@ -179,18 +179,23 @@ struct BinaryRecipe_match {
       return false;
     assert(R->getNumOperands() == 2 &&
            "recipe with matched opcode does not have 2 operands");
-    return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
+    if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)))
+      return true;
+    return Commutative && Op0.match(R->getOperand(1)) &&
+           Op1.match(R->getOperand(0));
   }
 };
 
 template <typename Op0_t, typename Op1_t, unsigned Opcode>
 using BinaryVPInstruction_match =
-    BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPInstruction>;
+    BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
+                       VPInstruction>;
 
-template <typename Op0_t, typename Op1_t, unsigned Opcode>
+template <typename Op0_t, typename Op1_t, unsigned Opcode,
+          bool Commutative = false>
 using AllBinaryRecipe_match =
-    BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
-                       VPWidenCastRecipe, VPInstruction>;
+    BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
+                       VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
 
 template <unsigned Opcode, typename Op0_t>
 inline UnaryVPInstruction_match<Op0_t, Opcode>
@@ -256,10 +261,11 @@ m_ZExtOrSExt(const Op0_t &Op0) {
   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
 }
 
-template <unsigned Opcode, typename Op0_t, typename Op1_t>
-inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode> m_Binary(const Op0_t &Op0,
-                                                            const Op1_t &Op1) {
-  return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
+template <unsigned Opcode, typename Op0_t, typename Op1_t,
+          bool Commutative = false>
+inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
+m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
+  return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
@@ -268,10 +274,21 @@ m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
 }
 
-template <typename Op0_t, typename Op1_t>
-inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or>
+/// Match a binary OR operation. Note that while conceptually the operands can
+/// be matched commutatively, \p Commutative defaults to false in line with the
+/// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
+/// version of the matcher.
+template <typename Op0_t, typename Op1_t, bool Commutative = false>
+inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
-  return m_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
+  return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
+                             /*Commutative*/ true>
+m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 4c968c2834b10..7ff8d8e0ea15d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -941,8 +941,8 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
   // recipes to be visited during simplification.
   VPValue *X, *Y, *X1, *Y1;
   if (match(&R,
-            m_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
-                       m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) &&
+            m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
+                         m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) &&
       X == X1 && Y == Y1) {
     R.getVPSingleValue()->replaceAllUsesWith(X);
     return;

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
index d335ac4b69709..200c2adcf0e66 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
@@ -402,10 +402,9 @@ define void @test_widen_if_then_else(ptr noalias %a, ptr readnone %b) #4 {
 ; TFCOMMON-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> zeroinitializer, <vscale x 2 x i1> [[TMP10]])
 ; TFCOMMON-NEXT:    [[TMP12:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i1> [[TMP8]], <vscale x 2 x i1> zeroinitializer
 ; TFCOMMON-NEXT:    [[TMP13:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD]], <vscale x 2 x i1> [[TMP12]])
-; TFCOMMON-NEXT:    [[TMP14:%.*]] = or <vscale x 2 x i1> [[TMP10]], [[TMP12]]
 ; TFCOMMON-NEXT:    [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP10]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[TMP13]]
 ; TFCOMMON-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i64, ptr [[B:%.*]], i64 [[INDEX]]
-; TFCOMMON-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP15]], i32 8, <vscale x 2 x i1> [[TMP14]])
+; TFCOMMON-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP15]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
 ; TFCOMMON-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP6]]
 ; TFCOMMON-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX_NEXT]], i64 1025)
 ; TFCOMMON-NEXT:    [[TMP16:%.*]] = xor <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer)
@@ -453,16 +452,14 @@ define void @test_widen_if_then_else(ptr noalias %a, ptr readnone %b) #4 {
 ; TFA_INTERLEAVE-NEXT:    [[TMP22:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK2]], <vscale x 2 x i1> [[TMP14]], <vscale x 2 x i1> zeroinitializer
 ; TFA_INTERLEAVE-NEXT:    [[TMP23:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD]], <vscale x 2 x i1> [[TMP21]])
 ; TFA_INTERLEAVE-NEXT:    [[TMP24:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD3]], <vscale x 2 x i1> [[TMP22]])
-; TFA_INTERLEAVE-NEXT:    [[TMP25:%.*]] = or <vscale x 2 x i1> [[TMP17]], [[TMP21]]
-; TFA_INTERLEAVE-NEXT:    [[TMP26:%.*]] = or <vscale x 2 x i1> [[TMP18]], [[TMP22]]
 ; TFA_INTERLEAVE-NEXT:    [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP17]], <vscale x 2 x i64> [[TMP19]], <vscale x 2 x i64> [[TMP23]]
 ; TFA_INTERLEAVE-NEXT:    [[PREDPHI4:%.*]] = select <vscale x 2 x i1> [[TMP18]], <vscale x 2 x i64> [[TMP20]], <vscale x 2 x i64> [[TMP24]]
 ; TFA_INTERLEAVE-NEXT:    [[TMP27:%.*]] = getelementptr inbounds i64, ptr [[B:%.*]], i64 [[INDEX]]
 ; TFA_INTERLEAVE-NEXT:    [[TMP28:%.*]] = call i64 @llvm.vscale.i64()
 ; TFA_INTERLEAVE-NEXT:    [[TMP29:%.*]] = mul i64 [[TMP28]], 2
 ; TFA_INTERLEAVE-NEXT:    [[TMP30:%.*]] = getelementptr inbounds i64, ptr [[TMP27]], i64 [[TMP29]]
-; TFA_INTERLEAVE-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP27]], i32 8, <vscale x 2 x i1> [[TMP25]])
-; TFA_INTERLEAVE-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI4]], ptr [[TMP30]], i32 8, <vscale x 2 x i1> [[TMP26]])
+; TFA_INTERLEAVE-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP27]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; TFA_INTERLEAVE-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI4]], ptr [[TMP30]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK2]])
 ; TFA_INTERLEAVE-NEXT:    [[INDEX_NEXT:%.*]] = add i64 [[INDEX]], [[TMP6]]
 ; TFA_INTERLEAVE-NEXT:    [[TMP31:%.*]] = call i64 @llvm.vscale.i64()
 ; TFA_INTERLEAVE-NEXT:    [[TMP32:%.*]] = mul i64 [[TMP31]], 2

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding.ll
index 2b2742ca7ccbc..63ad98b2d8ab2 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding.ll
@@ -480,11 +480,10 @@ define void @cond_uniform_load(ptr noalias %dst, ptr noalias readonly %src, ptr
 ; CHECK-NEXT:    [[TMP15:%.*]] = select <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i1> [[TMP14]], <vscale x 4 x i1> zeroinitializer
 ; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 4, <vscale x 4 x i1> [[TMP15]], <vscale x 4 x i32> poison)
 ; CHECK-NEXT:    [[TMP16:%.*]] = select <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i1> [[TMP13]], <vscale x 4 x i1> zeroinitializer
-; CHECK-NEXT:    [[TMP18:%.*]] = or <vscale x 4 x i1> [[TMP15]], [[TMP16]]
 ; CHECK-NEXT:    [[PREDPHI:%.*]] = select <vscale x 4 x i1> [[TMP16]], <vscale x 4 x i32> zeroinitializer, <vscale x 4 x i32> [[WIDE_MASKED_GATHER]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP10]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 0
-; CHECK-NEXT:    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[PREDPHI]], ptr [[TMP19]], i32 4, <vscale x 4 x i1> [[TMP18]])
+; CHECK-NEXT:    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[PREDPHI]], ptr [[TMP19]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
 ; CHECK-NEXT:    [[INDEX_NEXT2]] = add i64 [[INDEX1]], [[TMP21]]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX1]], i64 [[TMP9]])
 ; CHECK-NEXT:    [[TMP22:%.*]] = xor <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer)

