[llvm] [LLVM][InstCombine] Extend masked_gather's demanded elt analysis. (PR #151732)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 2 03:59:14 PDT 2025


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/151732

>From 0dc732da1505ef7c7ea2d6144954f721ebcd20d0 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 1 Aug 2025 10:17:40 +0000
Subject: [PATCH 1/2] [LLVM][InstCombine] Extend masked_gather's demanded elt
 analysis.

Add support for mask operands of any Constant type, not just ConstantVector.
---
 .../InstCombineSimplifyDemanded.cpp           | 22 +++++++++++++------
 .../InstCombine/masked_intrinsics.ll          |  1 +
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 0e3436d12702d..c82dae7ac6e65 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1834,14 +1834,22 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
       // segfaults which didn't exist in the original program.
       APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
           DemandedPassThrough(DemandedElts);
-      if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
-        for (unsigned i = 0; i < VWidth; i++) {
-          Constant *CElt = CV->getAggregateElement(i);
-          if (CElt->isNullValue())
-            DemandedPtrs.clearBit(i);
-          else if (CElt->isAllOnesValue())
-            DemandedPassThrough.clearBit(i);
+      if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) {
+        if (CMask->isNullValue())
+          DemandedPtrs.clearAllBits();
+        else if (CMask->isAllOnesValue())
+          DemandedPassThrough.clearAllBits();
+        else if (auto *CV = dyn_cast<ConstantVector>(CMask)) {
+          for (unsigned i = 0; i < VWidth; i++) {
+            Constant *CElt = CV->getAggregateElement(i);
+            if (CElt->isNullValue())
+              DemandedPtrs.clearBit(i);
+            else if (CElt->isAllOnesValue())
+              DemandedPassThrough.clearBit(i);
+          }
         }
+      }
+
       if (II->getIntrinsicID() == Intrinsic::masked_gather)
         simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
       simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index d9f022442a02e..8f7683419a82a 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+; RUN: opt -passes=instcombine -use-constant-int-for-fixed-length-splat -S < %s | FileCheck %s
 
 declare <2 x double> @llvm.masked.load.v2f64.p0(ptr %ptrs, i32, <2 x i1> %mask, <2 x double> %src0)
 declare void @llvm.masked.store.v2f64.p0(<2 x double> %val, ptr %ptrs, i32, <2 x i1> %mask)

>From efd0af43f81b6997272b353339d07f06f5e541e1 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Sat, 2 Aug 2025 11:58:27 +0100
Subject: [PATCH 2/2] Simplify code by iterating over the elements of any Constant mask.

---
 .../InstCombine/InstCombineSimplifyDemanded.cpp          | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index c82dae7ac6e65..f17fecd430a6c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1835,13 +1835,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
       APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
           DemandedPassThrough(DemandedElts);
       if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) {
-        if (CMask->isNullValue())
-          DemandedPtrs.clearAllBits();
-        else if (CMask->isAllOnesValue())
-          DemandedPassThrough.clearAllBits();
-        else if (auto *CV = dyn_cast<ConstantVector>(CMask)) {
-          for (unsigned i = 0; i < VWidth; i++) {
-            Constant *CElt = CV->getAggregateElement(i);
+        for (unsigned i = 0; i < VWidth; i++) {
+          if (Constant *CElt = CMask->getAggregateElement(i)) {
             if (CElt->isNullValue())
               DemandedPtrs.clearBit(i);
             else if (CElt->isAllOnesValue())



More information about the llvm-commits mailing list