[Mlir-commits] [mlir] [InstCombine] Refactor matchFunnelShift to allow more pattern (NFC) (PR #68474)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Oct 7 20:26:16 PDT 2023


https://github.com/HaohaiWen updated https://github.com/llvm/llvm-project/pull/68474

>From 5b3b1bbb5b263bc5711adde031d85b1461ccbab6 Mon Sep 17 00:00:00 2001
From: Haohai Wen <haohai.wen at intel.com>
Date: Sat, 7 Oct 2023 13:48:32 +0800
Subject: [PATCH] [InstCombine] Refactor matchFunnelShift to allow more pattern
 (NFC)

The current implementation of matchFunnelShift only matches the
opposite-shift pattern. Refactor it so that more patterns can be added.
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 172 ++++++++++--------
 1 file changed, 93 insertions(+), 79 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index cbdab3e9c5fb91d..b04e070fd19d7d1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2732,100 +2732,114 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();
 
-  // First, find an or'd pair of opposite shifts:
-  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
-  BinaryOperator *Or0, *Or1;
-  if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
-      !match(Or.getOperand(1), m_BinOp(Or1)))
-    return nullptr;
-
-  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
-  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
-      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
-      Or0->getOpcode() == Or1->getOpcode())
+  Instruction *Or0, *Or1;
+  if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+      !match(Or.getOperand(1), m_Instruction(Or1)))
     return nullptr;
 
-  // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
-  if (Or0->getOpcode() == BinaryOperator::LShr) {
-    std::swap(Or0, Or1);
-    std::swap(ShVal0, ShVal1);
-    std::swap(ShAmt0, ShAmt1);
-  }
-  assert(Or0->getOpcode() == BinaryOperator::Shl &&
-         Or1->getOpcode() == BinaryOperator::LShr &&
-         "Illegal or(shift,shift) pair");
-
-  // Match the shift amount operands for a funnel shift pattern. This always
-  // matches a subtraction on the R operand.
-  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
-    // Check for constant shift amounts that sum to the bitwidth.
-    const APInt *LI, *RI;
-    if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
-      if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
-        return ConstantInt::get(L->getType(), *LI);
-
-    Constant *LC, *RC;
-    if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
-        match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
-      return ConstantExpr::mergeUndefsWith(LC, RC);
-
-    // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
-    // We limit this to X < Width in case the backend re-expands the intrinsic,
-    // and has to reintroduce a shift modulo operation (InstCombine might remove
-    // it after this fold). This still doesn't guarantee that the final codegen
-    // will match this original pattern.
-    if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-      KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
-      return KnownL.getMaxValue().ult(Width) ? L : nullptr;
-    }
+  bool IsFshl = true; // Sub on LSHR.
+  SmallVector<Value *, 3> FShiftArgs;
 
-    // For non-constant cases, the following patterns currently only work for
-    // rotation patterns.
-    // TODO: Add general funnel-shift compatible patterns.
-    if (ShVal0 != ShVal1)
+  // First, find an or'd pair of opposite shifts:
+  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
+  if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+    Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+    if (!match(Or0,
+               m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+        !match(Or1,
+               m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+        Or0->getOpcode() == Or1->getOpcode())
       return nullptr;
 
-    // For non-constant cases we don't support non-pow2 shift masks.
-    // TODO: Is it worth matching urem as well?
-    if (!isPowerOf2_32(Width))
-      return nullptr;
+    // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+    if (Or0->getOpcode() == BinaryOperator::LShr) {
+      std::swap(Or0, Or1);
+      std::swap(ShVal0, ShVal1);
+      std::swap(ShAmt0, ShAmt1);
+    }
+    assert(Or0->getOpcode() == BinaryOperator::Shl &&
+           Or1->getOpcode() == BinaryOperator::LShr &&
+           "Illegal or(shift,shift) pair");
+
+    // Match the shift amount operands for a funnel shift pattern. This always
+    // matches a subtraction on the R operand.
+    auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+      // Check for constant shift amounts that sum to the bitwidth.
+      const APInt *LI, *RI;
+      if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+        if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+          return ConstantInt::get(L->getType(), *LI);
+
+      Constant *LC, *RC;
+      if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+          match(L,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(R,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+        return ConstantExpr::mergeUndefsWith(LC, RC);
+
+      // (shl ShVal, X) | (lshr ShVal, (Width - X)) iff X < Width.
+      // We limit this to X < Width in case the backend re-expands the
+      // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+      // might remove it after this fold). This still doesn't guarantee that the
+      // final codegen will match this original pattern.
+      if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+      }
 
-    // The shift amount may be masked with negation:
-    // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
-    Value *X;
-    unsigned Mask = Width - 1;
-    if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
-        match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
-      return X;
+      // For non-constant cases, the following patterns currently only work for
+      // rotation patterns.
+      // TODO: Add general funnel-shift compatible patterns.
+      if (ShVal0 != ShVal1)
+        return nullptr;
 
-    // Similar to above, but the shift amount may be extended after masking,
-    // so return the extended value as the parameter for the intrinsic.
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
-                       m_SpecificInt(Mask))))
-      return L;
+      // For non-constant cases we don't support non-pow2 shift masks.
+      // TODO: Is it worth matching urem as well?
+      if (!isPowerOf2_32(Width))
+        return nullptr;
 
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
-      return L;
+      // The shift amount may be masked with negation:
+      // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+      Value *X;
+      unsigned Mask = Width - 1;
+      if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+          match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+        return X;
+
+      // Similar to above, but the shift amount may be extended after masking,
+      // so return the extended value as the parameter for the intrinsic.
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R,
+                m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+                      m_SpecificInt(Mask))))
+        return L;
+
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+        return L;
 
-    return nullptr;
-  };
+      return nullptr;
+    };
 
-  Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
-  bool IsFshl = true; // Sub on LSHR.
-  if (!ShAmt) {
-    ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
-    IsFshl = false; // Sub on SHL.
+    Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+    if (!ShAmt) {
+      ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+      IsFshl = false; // Sub on SHL.
+    }
+    if (!ShAmt)
+      return nullptr;
+
+    FShiftArgs = {ShVal0, ShVal1, ShAmt};
   }
-  if (!ShAmt)
+
+  if (FShiftArgs.empty())
     return nullptr;
 
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
+  return CallInst::Create(F, FShiftArgs);
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.



More information about the Mlir-commits mailing list