[llvm] [InstCombine] Allow folds of shifts by constants for scalable vectors again (PR #132522)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 23 21:18:33 PDT 2025


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/132522

>From 5cb140ce08525ecc357af73f63ed8d9275d799c9 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Sat, 22 Mar 2025 12:33:08 +0800
Subject: [PATCH 1/3] Precommit tests

---
 llvm/test/Transforms/InstCombine/shl-bo.ll            | 11 +++++++++++
 .../test/Transforms/InstCombine/shl-twice-constant.ll | 11 +++++++++++
 2 files changed, 22 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/shl-bo.ll b/llvm/test/Transforms/InstCombine/shl-bo.ll
index c32ac2eacb25a..03b0080fb7c37 100644
--- a/llvm/test/Transforms/InstCombine/shl-bo.ll
+++ b/llvm/test/Transforms/InstCombine/shl-bo.ll
@@ -656,3 +656,14 @@ define <16 x i8> @test_FoldShiftByConstant_CreateAnd(<16 x i8> %in0) {
   %vshl_n = shl <16 x i8> %tmp, <i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5>
   ret <16 x i8> %vshl_n
 }
+
+define <vscale x 1 x i8> @test_FoldShiftByConstant_CreateAnd_scalable(<vscale x 1 x i8> %x) {
+; CHECK-LABEL: @test_FoldShiftByConstant_CreateAnd_scalable(
+; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i8> [[X:%.*]], splat (i8 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw <vscale x 1 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP2]]
+;
+  %1 = and <vscale x 1 x i8> %x, splat (i8 2)
+  %2 = shl <vscale x 1 x i8> %1, splat (i8 2)
+  ret <vscale x 1 x i8> %2
+}
diff --git a/llvm/test/Transforms/InstCombine/shl-twice-constant.ll b/llvm/test/Transforms/InstCombine/shl-twice-constant.ll
index bbdd7fa3d1c40..151db29fe3e5f 100644
--- a/llvm/test/Transforms/InstCombine/shl-twice-constant.ll
+++ b/llvm/test/Transforms/InstCombine/shl-twice-constant.ll
@@ -14,3 +14,14 @@ define i64 @testfunc() {
   %shl2 = shl i64 %shl1, ptrtoint (ptr @c to i64)
   ret i64 %shl2
 }
+
+define <vscale x 1 x i64> @scalable() {
+; CHECK-LABEL: @scalable(
+; CHECK-NEXT:    [[SHL1:%.*]] = shl nuw <vscale x 1 x i64> splat (i64 1), shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 ptrtoint (ptr @c2 to i64), i64 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
+; CHECK-NEXT:    [[SHL2:%.*]] = shl <vscale x 1 x i64> [[SHL1]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 ptrtoint (ptr @c to i64), i64 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[SHL2]]
+;
+  %shl1 = shl <vscale x 1 x i64> splat (i64 1), splat (i64 ptrtoint (ptr @c2 to i64))
+  %shl2 = shl <vscale x 1 x i64> %shl1, splat (i64 ptrtoint (ptr @c to i64))
+  ret <vscale x 1 x i64> %shl2
+}

>From 25172e1cc53821fe95c37569c3e8f69f5d4ba9ec Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Sat, 22 Mar 2025 12:57:00 +0800
Subject: [PATCH 2/3] [InstCombine] Allow folds of shifts by constants for
 scalable vectors again

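FoldShiftByConstant is only tried when the shift amount matches m_ImmConstant, which rejects ConstantExprs. However, this meant that FoldShiftByConstant no longer kicked in for scalable vectors, because scalable splats are represented by ConstantExprs.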

This fixes it by explicitly allowing splats of ConstantInts. It's not the prettiest approach, so I'm open to suggestions.

I'm also hoping that enabling UseConstantIntForScalableSplat by default will eventually remove the need for this.

I noticed this when trying to reverse a combine on RISC-V in #132245, and saw that the resulting vector and scalar forms were different.
---
 llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 5 ++++-
 llvm/test/Transforms/InstCombine/shl-bo.ll            | 4 ++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 90cd279e8a457..91174cc79cd2b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -428,7 +428,10 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
         return R;
 
   Constant *CUI;
-  if (match(Op1, m_ImmConstant(CUI)))
+  if (match(Op1, m_Constant(CUI)) &&
+      (!isa<ConstantExpr>(CUI) ||
+       (Ty->isVectorTy() &&
+        isa_and_present<ConstantInt>(CUI->getSplatValue()))))
     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
       return Res;
 
diff --git a/llvm/test/Transforms/InstCombine/shl-bo.ll b/llvm/test/Transforms/InstCombine/shl-bo.ll
index 03b0080fb7c37..5ee8716d5d119 100644
--- a/llvm/test/Transforms/InstCombine/shl-bo.ll
+++ b/llvm/test/Transforms/InstCombine/shl-bo.ll
@@ -659,8 +659,8 @@ define <16 x i8> @test_FoldShiftByConstant_CreateAnd(<16 x i8> %in0) {
 
 define <vscale x 1 x i8> @test_FoldShiftByConstant_CreateAnd_scalable(<vscale x 1 x i8> %x) {
 ; CHECK-LABEL: @test_FoldShiftByConstant_CreateAnd_scalable(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i8> [[X:%.*]], splat (i8 2)
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw <vscale x 1 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <vscale x 1 x i8> [[X:%.*]], splat (i8 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 1 x i8> [[TMP1]], splat (i8 8)
 ; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP2]]
 ;
   %1 = and <vscale x 1 x i8> %x, splat (i8 2)

>From a0f4ac6e2ef13f90ce44524a6943f98583e43367 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 24 Mar 2025 12:14:35 +0800
Subject: [PATCH 3/3] Rework to adjust m_ImmConstant

---
 llvm/include/llvm/IR/PatternMatch.h           | 51 +++++++++++++++----
 .../InstCombine/InstCombineShifts.cpp         |  5 +-
 llvm/test/Transforms/InstCombine/select.ll    |  3 +-
 llvm/test/Transforms/InstCombine/sub.ll       |  2 +-
 4 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index b3eeb1d7ba88a..ff3f7735bfa85 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -858,18 +858,51 @@ inline bind_ty<const BasicBlock> m_BasicBlock(const BasicBlock *&V) {
   return V;
 }
 
+// Scalable splats are currently ConstantExprs, hence the splat check. TODO: use
+// m_Unless(m_ConstantExpr) once UseConstant{Int,FP}ForScalableSplat is default.
+struct immconstant_ty {
+  template <typename ITy> static bool isImmConstant(ITy *V) {
+    if (auto *CV = dyn_cast<Constant>(V)) {
+      if (!isa<ConstantExpr>(CV) && !CV->containsConstantExpression())
+        return true;
+
+      if (CV->getType()->isVectorTy()) {
+        if (auto *Splat = CV->getSplatValue(/*AllowPoison=*/true)) {
+          if (!isa<ConstantExpr>(Splat) &&
+              !Splat->containsConstantExpression()) {
+            return true;
+          }
+        }
+      }
+    }
+    return false;
+  }
+};
+
+struct match_immconstant_ty : immconstant_ty {
+  template <typename ITy> bool match(ITy *V) { return isImmConstant(V); }
+};
+
 /// Match an arbitrary immediate Constant and ignore it.
-inline match_combine_and<class_match<Constant>,
-                         match_unless<constantexpr_match>>
-m_ImmConstant() {
-  return m_CombineAnd(m_Constant(), m_Unless(m_ConstantExpr()));
-}
+inline match_immconstant_ty m_ImmConstant() { return match_immconstant_ty(); }
+
+struct bind_immconstant_ty : immconstant_ty {
+  Constant *&VR;
+
+  bind_immconstant_ty(Constant *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    if (isImmConstant(V)) {
+      VR = cast<Constant>(V);
+      return true;
+    }
+    return false;
+  }
+};
 
 /// Match an immediate Constant, capturing the value if we match.
-inline match_combine_and<bind_ty<Constant>,
-                         match_unless<constantexpr_match>>
-m_ImmConstant(Constant *&C) {
-  return m_CombineAnd(m_Constant(C), m_Unless(m_ConstantExpr()));
+inline bind_immconstant_ty m_ImmConstant(Constant *&C) {
+  return bind_immconstant_ty(C);
 }
 
 /// Match a specified Value*.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 91174cc79cd2b..90cd279e8a457 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -428,10 +428,7 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
         return R;
 
   Constant *CUI;
-  if (match(Op1, m_Constant(CUI)) &&
-      (!isa<ConstantExpr>(CUI) ||
-       (Ty->isVectorTy() &&
-        isa_and_present<ConstantInt>(CUI->getSplatValue()))))
+  if (match(Op1, m_ImmConstant(CUI)))
     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
       return Res;
 
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 2078b795817f8..3d81b72dd232e 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3519,8 +3519,7 @@ define <vscale x 2 x i32> @scalable_sign_bits(<vscale x 2 x i8> %x) {
 
 define <vscale x 2 x i1> @scalable_non_zero(<vscale x 2 x i32> %x) {
 ; CHECK-LABEL: @scalable_non_zero(
-; CHECK-NEXT:    [[A:%.*]] = or <vscale x 2 x i32> [[X:%.*]], splat (i32 1)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i32> [[A]], splat (i32 57)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i32> [[X:%.*]], splat (i32 56)
 ; CHECK-NEXT:    ret <vscale x 2 x i1> [[CMP]]
 ;
   %a = or <vscale x 2 x i32> %x, splat (i32 1)
diff --git a/llvm/test/Transforms/InstCombine/sub.ll b/llvm/test/Transforms/InstCombine/sub.ll
index e89419d1f3838..534768eb3394f 100644
--- a/llvm/test/Transforms/InstCombine/sub.ll
+++ b/llvm/test/Transforms/InstCombine/sub.ll
@@ -861,7 +861,7 @@ define <2 x i16> @test44vecminval(<2 x i16> %x) {
 ; uses m_ImmConstant which matches Constant but (explicitly) not ConstantExpr.
 define <vscale x 2 x i16> @test44scalablevecminval(<vscale x 2 x i16> %x) {
 ; CHECK-LABEL: @test44scalablevecminval(
-; CHECK-NEXT:    [[SUB:%.*]] = add <vscale x 2 x i16> [[X:%.*]], splat (i16 -32768)
+; CHECK-NEXT:    [[SUB:%.*]] = xor <vscale x 2 x i16> [[X:%.*]], splat (i16 -32768)
 ; CHECK-NEXT:    ret <vscale x 2 x i16> [[SUB]]
 ;
   %sub = sub nsw <vscale x 2 x i16> %x, splat (i16 -32768)



More information about the llvm-commits mailing list