[llvm] e13c84c - GlobalISel: Work on improving stock set of legality predicates

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 28 17:28:30 PDT 2020


Author: Matt Arsenault
Date: 2020-05-28T20:28:24-04:00
New Revision: e13c84c3be589c80edd2391664e136f54f0e3345

URL: https://github.com/llvm/llvm-project/commit/e13c84c3be589c80edd2391664e136f54f0e3345
DIFF: https://github.com/llvm/llvm-project/commit/e13c84c3be589c80edd2391664e136f54f0e3345.diff

LOG: GlobalISel: Work on improving stock set of legality predicates

I get confused by a lot of the predicate names here, since I would
assume they apply to vectors as well. Rename them to reflect that they
only apply to scalars.

Also move a few predicates that AMDGPU defined locally, and that should
be generally useful, into the common set. Also add any() to complement
all(). I've wanted to use this a few times but worked around its absence.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
    llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
    llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index f913f5f41b8e..49bc66a89a21 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -203,6 +203,20 @@ template<typename Predicate, typename... Args>
 Predicate all(Predicate P0, Predicate P1, Args... args) {
   return all(all(P0, P1), args...);
 }
+
+/// True iff at least one of P0 and P1 is true.
+template<typename Predicate>
+Predicate any(Predicate P0, Predicate P1) {
+  return [=](const LegalityQuery &Query) {
+    return P0(Query) || P1(Query);
+  };
+}
+/// True iff at least one of the given predicates is true.
+template<typename Predicate, typename... Args>
+Predicate any(Predicate P0, Predicate P1, Args... args) {
+  return any(any(P0, P1), args...);
+}
+
 /// True iff the given type index is the specified types.
 LegalityPredicate typeIs(unsigned TypeIdx, LLT TypesInit);
 /// True iff the given type index is one of the specified types.
@@ -228,13 +242,16 @@ LegalityPredicate isPointer(unsigned TypeIdx);
 /// space.
 LegalityPredicate isPointer(unsigned TypeIdx, unsigned AddrSpace);
 
+/// True iff the specified type index is a vector with element type \p EltTy.
+LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT EltTy);
+
 /// True iff the specified type index is a scalar that's narrower than the given
 /// size.
-LegalityPredicate narrowerThan(unsigned TypeIdx, unsigned Size);
+LegalityPredicate scalarNarrowerThan(unsigned TypeIdx, unsigned Size);
 
 /// True iff the specified type index is a scalar that's wider than the given
 /// size.
-LegalityPredicate widerThan(unsigned TypeIdx, unsigned Size);
+LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 
 /// True iff the specified type index is a scalar or vector with an element type
 /// that's narrower than the given size.
@@ -257,6 +274,15 @@ LegalityPredicate sizeIs(unsigned TypeIdx, unsigned Size);
 
 /// True iff the specified type indices are both the same bit size.
 LegalityPredicate sameSize(unsigned TypeIdx0, unsigned TypeIdx1);
+
+/// True iff the first type index has a larger total bit size than the second
+/// type index.
+LegalityPredicate largerThan(unsigned TypeIdx0, unsigned TypeIdx1);
+
+/// True iff the first type index has a smaller total bit size than the second
+/// type index.
+LegalityPredicate smallerThan(unsigned TypeIdx0, unsigned TypeIdx1);
+
 /// True iff the specified MMO index has a size that is not a power of 2
 LegalityPredicate memSizeInBytesNotPow2(unsigned MMOIdx);
 /// True iff the specified type index is a vector whose element count is not a
@@ -774,7 +800,7 @@ class LegalizeRuleSet {
     using namespace LegalityPredicates;
     using namespace LegalizeMutations;
     return actionIf(LegalizeAction::WidenScalar,
-                    narrowerThan(TypeIdx, Ty.getSizeInBits()),
+                    scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
 
@@ -792,7 +818,7 @@ class LegalizeRuleSet {
     using namespace LegalityPredicates;
     using namespace LegalizeMutations;
     return actionIf(LegalizeAction::NarrowScalar,
-                    widerThan(TypeIdx, Ty.getSizeInBits()),
+                    scalarWiderThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
 
@@ -806,7 +832,10 @@ class LegalizeRuleSet {
     return actionIf(
         LegalizeAction::NarrowScalar,
         [=](const LegalityQuery &Query) {
-          return widerThan(TypeIdx, Ty.getSizeInBits()) && Predicate(Query);
+          // scalarWiderThan() returns a predicate; it must be called on Query.
+          // Converting the predicate object to bool would always yield true.
+          return scalarWiderThan(TypeIdx, Ty.getSizeInBits())(Query) &&
+                 Predicate(Query);
         },
         changeElementTo(typeIdx(TypeIdx), Ty));
   }

diff  --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index b6fb061a8334..a83742f2138f 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -80,22 +80,46 @@ LegalityPredicate LegalityPredicates::isPointer(unsigned TypeIdx,
   };
 }
 
-LegalityPredicate LegalityPredicates::narrowerThan(unsigned TypeIdx,
-                                                   unsigned Size) {
+LegalityPredicate LegalityPredicates::elementTypeIs(unsigned TypeIdx,
+                                                    LLT EltTy) {
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isVector() && QueryTy.getElementType() == EltTy;
+  };
+}
+
+LegalityPredicate LegalityPredicates::scalarNarrowerThan(unsigned TypeIdx,
+                                                         unsigned Size) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
     return QueryTy.isScalar() && QueryTy.getSizeInBits() < Size;
   };
 }
 
-LegalityPredicate LegalityPredicates::widerThan(unsigned TypeIdx,
-                                                unsigned Size) {
+LegalityPredicate LegalityPredicates::scalarWiderThan(unsigned TypeIdx,
+                                                      unsigned Size) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
     return QueryTy.isScalar() && QueryTy.getSizeInBits() > Size;
   };
 }
 
+LegalityPredicate LegalityPredicates::smallerThan(unsigned TypeIdx0,
+                                                  unsigned TypeIdx1) {
+  return [=](const LegalityQuery &Query) {
+    return Query.Types[TypeIdx0].getSizeInBits() <
+           Query.Types[TypeIdx1].getSizeInBits();
+  };
+}
+
+LegalityPredicate LegalityPredicates::largerThan(unsigned TypeIdx0,
+                                                 unsigned TypeIdx1) {
+  return [=](const LegalityQuery &Query) {
+    return Query.Types[TypeIdx0].getSizeInBits() >
+           Query.Types[TypeIdx1].getSizeInBits();
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                               unsigned Size) {
   return [=](const LegalityQuery &Query) {

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 74e03e1d9919..2a546433a245 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -158,13 +158,6 @@ static LegalityPredicate isRegisterType(unsigned TypeIdx) {
   };
 }
 
-static LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT Type) {
-  return [=](const LegalityQuery &Query) {
-    const LLT QueryTy = Query.Types[TypeIdx];
-    return QueryTy.isVector() && QueryTy.getElementType() == Type;
-  };
-}
-
 static LegalityPredicate elementTypeIsLegal(unsigned TypeIdx) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
@@ -183,20 +176,6 @@ static LegalityPredicate isWideScalarTruncStore(unsigned TypeIdx) {
   };
 }
 
-static LegalityPredicate smallerThan(unsigned TypeIdx0, unsigned TypeIdx1) {
-  return [=](const LegalityQuery &Query) {
-    return Query.Types[TypeIdx0].getSizeInBits() <
-           Query.Types[TypeIdx1].getSizeInBits();
-  };
-}
-
-static LegalityPredicate greaterThan(unsigned TypeIdx0, unsigned TypeIdx1) {
-  return [=](const LegalityQuery &Query) {
-    return Query.Types[TypeIdx0].getSizeInBits() >
-           Query.Types[TypeIdx1].getSizeInBits();
-  };
-}
-
 AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
                                          const GCNTargetMachine &TM)
   :  ST(ST_) {
@@ -680,7 +659,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
     // TODO: Should have same legality without v_perm_b32
     getActionDefinitionsBuilder(G_BSWAP)
       .legalFor({S32})
-      .lowerIf(narrowerThan(0, 32))
+      .lowerIf(scalarNarrowerThan(0, 32))
       // FIXME: Fixing non-power-of-2 before clamp is workaround for
       // narrowScalar limitation.
       .widenScalarToNextPow2(0)
@@ -707,7 +686,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
       [](const LegalityQuery &Query) {
         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
       })
-    .narrowScalarIf(greaterThan(1, 0),
+    .narrowScalarIf(largerThan(1, 0),
       [](const LegalityQuery &Query) {
         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
       });
@@ -724,7 +703,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
       })
     .narrowScalarIf(
-      greaterThan(0, 1),
+      largerThan(0, 1),
       [](const LegalityQuery &Query) {
         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
       });
@@ -1238,7 +1217,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
         })
       // Try to widen to s16 first for small types.
       // TODO: Only do this on targets with legal s16 shifts
-      .minScalarOrEltIf(narrowerThan(LitTyIdx, 16), LitTyIdx, S16)
+      .minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
       .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
       .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
       .fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),


        


More information about the llvm-commits mailing list