[clang] [X86] Allow PSHUFD/PSHUFLW/PSHUFW intrinsics in constexpr. (PR #161210)

Timm Baeder via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 29 08:00:10 PDT 2025


================
@@ -2862,6 +2862,218 @@ static bool interp__builtin_blend(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_ia32_pshuflw_common(InterpState &S, CodePtr OpPC, // Per 128-bit lane: permutes the low four 16-bit elements by an 8-bit immediate and copies the high four unchanged.
+                                                const CallExpr *Call) { // Handles 2 args (unmasked), 3 args (zeroing mask) and 4 args (merge mask); always returns true.
+  const unsigned NumArgs = Call->getNumArgs();
+  assert(NumArgs == 2 || NumArgs == 3 || NumArgs == 4);
+  APSInt K; // Per-element write mask; only populated for the 3- and 4-argument forms.
+  Pointer SrcPT; // Fallback source for masked-off elements (4-argument form only); "PT" presumably abbreviates "pass-through" - confirm with author.
+  const bool HasMask = (NumArgs == 3) || (NumArgs == 4);
+  const bool IsMaskZ = (NumArgs == 3); // 3-argument form: masked-off elements are written as zero.
+  if (NumArgs == 4) { // Arguments are popped in reverse order: mask (arg 3) first, then the SrcPT vector.
+    K = popToAPSInt(S, Call->getArg(3));
+    SrcPT = S.Stk.pop<Pointer>();
+  } else if (NumArgs == 3) {
+    K = popToAPSInt(S, Call->getArg(2));
+  }
+
+  APSInt Imm = popToAPSInt(S, Call->getArg(1)); // Shuffle control immediate.
+  const Pointer &Src = S.Stk.pop<Pointer>(); // Input vector.
+  const Pointer &Dst = S.Stk.peek<Pointer>(); // Result vector; peeked rather than popped.
+  const unsigned NumElems = Dst.getNumElems();
+  const PrimType ElemT = Dst.getFieldDesc()->getPrimType();
+  const unsigned ElemBits = 16;
+  const unsigned LaneElems = 128u / ElemBits; // 8 elements per 128-bit lane.
+  const unsigned Half = 4; // Positions 0..3 of a lane are shuffled; 4..7 are copied.
+  assert(NumElems % LaneElems == 0 && "pshuflw expects 128-bit lanes");
+  const uint8_t Ctl = static_cast<uint8_t>(Imm.getZExtValue()); // Only the low 8 bits of the immediate are used.
+
+  for (unsigned i = 0; i != NumElems; ++i) {
+    const unsigned laneBase = (i / LaneElems) * LaneElems; // Index of the first element of i's lane.
+    const unsigned inLane = i % LaneElems; // Position of i within its lane.
+
+    unsigned srcIdx;
+    if (inLane < Half) {
+      const unsigned pos = inLane;
+      const unsigned sel = (Ctl >> (2 * pos)) & 0x3; // 2-bit selector for this position.
+      srcIdx = laneBase + sel; // Selects one of the low four elements of the same lane.
+    } else {
+      srcIdx = i; // High half: element is taken from the same index.
+    }
+
+    APSInt Chosen; // Value read from Src at the computed source index.
+    INT_TYPE_SWITCH(ElemT, { Chosen = Src.elem<T>(srcIdx).toAPSInt(); });
+
+    if (!HasMask) { // Unmasked form: store the shuffled value unconditionally.
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Chosen); });
+      continue;
+    }
+
+    const bool Keep = // Bit i of K; indices at or beyond K's bit width count as masked off.
+        (i < static_cast<unsigned>(K.getBitWidth())) ? K[i] : false;
+
+    if (Keep) { // Mask bit set: store the shuffled value.
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Chosen); });
+    } else if (IsMaskZ) { // Zeroing form: store a zero with Chosen's width and signedness.
+      APSInt Zero(APInt(Chosen.getBitWidth(), 0));
+      Zero.setIsSigned(Chosen.isSigned());
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Zero); });
+    } else { // Merge form: copy element i from SrcPT.
+      APSInt PT;
+      INT_TYPE_SWITCH(ElemT, { PT = SrcPT.elem<T>(i).toAPSInt(); });
+      INT_TYPE_SWITCH_NO_BOOL(ElemT, { Dst.elem<T>(i) = static_cast<T>(PT); });
+    }
+  }
+
+  Dst.initializeAllElements(); // Going by the method name, marks every element of Dst as initialized.
+  return true;
+}
+
+static bool interp__builtin_ia32_pshufhw_common(InterpState &S, CodePtr OpPC, // Per 128-bit lane: permutes the high four 16-bit elements by an 8-bit immediate and copies the low four unchanged.
+                                                const CallExpr *Call) { // Handles 2 args (unmasked), 3 args (zeroing mask) and 4 args (merge mask); always returns true.
+  (void)OpPC; // OpPC is not otherwise used in this function.
+  const unsigned NumArgs = Call->getNumArgs();
+  assert(NumArgs == 2 || NumArgs == 3 || NumArgs == 4);
+
+  APSInt K; // Per-element write mask; only populated for the 3- and 4-argument forms.
+  Pointer SrcPT; // Fallback source for masked-off elements (4-argument form only); "PT" presumably abbreviates "pass-through" - confirm with author.
+  const bool HasMask = (NumArgs == 3) || (NumArgs == 4);
+  const bool IsMaskZ = (NumArgs == 3); // 3-argument form: masked-off elements are written as zero.
+
+  if (NumArgs == 4) { // Arguments are popped in reverse order: mask (arg 3) first, then the SrcPT vector.
+    K = popToAPSInt(S, Call->getArg(3));
+    SrcPT = S.Stk.pop<Pointer>();
+  } else if (NumArgs == 3) {
+    K = popToAPSInt(S, Call->getArg(2));
+  }
+
+  APSInt Imm = popToAPSInt(S, Call->getArg(1)); // Shuffle control immediate.
+  const Pointer &Src = S.Stk.pop<Pointer>(); // Input vector.
+  const Pointer &Dst = S.Stk.peek<Pointer>(); // Result vector; peeked rather than popped.
+
+  const unsigned NumElems = Dst.getNumElems();
+  const PrimType ElemT = Dst.getFieldDesc()->getPrimType();
+
+  const unsigned ElemBits = 16;
+  const unsigned LaneElems = 128u / ElemBits; // 8 elements per 128-bit lane.
+  const unsigned HalfBase = 4; // Positions 4..7 of a lane are shuffled; 0..3 are copied.
+  assert(NumElems % LaneElems == 0);
+
+  const uint8_t Ctl = static_cast<uint8_t>(Imm.getZExtValue()); // Only the low 8 bits of the immediate are used.
+
+  for (unsigned i = 0; i != NumElems; ++i) {
+    const unsigned laneBase = (i / LaneElems) * LaneElems; // Index of the first element of i's lane.
+    const unsigned inLane = i % LaneElems; // Position of i within its lane.
+
+    unsigned srcIdx;
+    if (inLane >= HalfBase) {
+      const unsigned pos = inLane - HalfBase; // Position within the high half (0..3).
+      const unsigned sel = (Ctl >> (2 * pos)) & 0x3; // 2-bit selector for this position.
+      srcIdx = laneBase + HalfBase + sel; // Selects one of the high four elements of the same lane.
+    } else {
+      srcIdx = i; // Low half: element is taken from the same index.
+    }
+
+    APSInt Chosen; // Value read from Src at the computed source index.
+    INT_TYPE_SWITCH(ElemT, { Chosen = Src.elem<T>(srcIdx).toAPSInt(); });
+
+    if (!HasMask) { // Unmasked form: store the shuffled value unconditionally.
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Chosen); });
+      continue;
+    }
+
+    const bool Keep = // Bit i of K; indices at or beyond K's bit width count as masked off.
+        (i < static_cast<unsigned>(K.getBitWidth())) ? K[i] : false;
+    if (Keep) { // Mask bit set: store the shuffled value.
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Chosen); });
+    } else if (IsMaskZ) { // Zeroing form: store a zero with Chosen's width and signedness.
+      APSInt Zero(APInt(Chosen.getBitWidth(), 0));
+      Zero.setIsSigned(Chosen.isSigned());
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(i) = static_cast<T>(Zero); });
+    } else { // Merge form: copy element i from SrcPT.
+      APSInt PT;
+      INT_TYPE_SWITCH(ElemT, { PT = SrcPT.elem<T>(i).toAPSInt(); });
+      INT_TYPE_SWITCH_NO_BOOL(ElemT, { Dst.elem<T>(i) = static_cast<T>(PT); });
+    }
+  }
+
+  Dst.initializeAllElements(); // Going by the method name, marks every element of Dst as initialized.
+  return true;
+}
+
+static bool interp__builtin_ia32_pshufd_common(InterpState &S, CodePtr OpPC,
+                                              const CallExpr *Call) {
+  (void)OpPC;
+  const unsigned NumArgs = Call->getNumArgs();
+  assert(NumArgs == 2 || NumArgs == 3 || NumArgs == 4);
+
+  APSInt K;
+  Pointer SrcPT;
----------------
tbaederr wrote:

What does "PT" mean?

https://github.com/llvm/llvm-project/pull/161210


More information about the cfe-commits mailing list