[llvm] 593e25f - [Vectorize] Fix vectorization, scalarization and folding of llvm.is.fpclass

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 24 05:42:22 PDT 2023


Author: Jay Foad
Date: 2023-04-24T13:42:08+01:00
New Revision: 593e25ffae4a1d67acf61c485220886ffbb0b3e7

URL: https://github.com/llvm/llvm-project/commit/593e25ffae4a1d67acf61c485220886ffbb0b3e7
DIFF: https://github.com/llvm/llvm-project/commit/593e25ffae4a1d67acf61c485220886ffbb0b3e7.diff

LOG: [Vectorize] Fix vectorization, scalarization and folding of llvm.is.fpclass

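llvm.is.fpclass is different from other vectorizable intrinsics in that
it is overloaded on an argument type, not on the return type. To handle
this, isVectorIntrinsicWithOverloadTypeAtArg now accepts an operand
index of -1 to ask about the return type, and the scalarizer, SLP
vectorizer and loop vectorizer only add the return type to the
overload type list when the intrinsic is actually overloaded on it.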

Differential Revision: https://reviews.llvm.org/D148905

Added: 
    llvm/test/Transforms/InstSimplify/is_fpclass.ll

Modified: 
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/Transforms/Scalar/Scalarizer.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
    llvm/test/Transforms/LoopVectorize/is_fpclass.ll
    llvm/test/Transforms/SLPVectorizer/is_fpclass.ll
    llvm/test/Transforms/Scalarizer/intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 95fdff81c83fc..5d824540c3896 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -347,9 +347,9 @@ bool isTriviallyVectorizable(Intrinsic::ID ID);
 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                         unsigned ScalarOpdIdx);
 
-/// Identifies if the vector form of the intrinsic has a operand that has
-/// an overloaded type.
-bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, unsigned OpdIdx);
+/// Identifies if the vector form of the intrinsic is overloaded on the type of
+/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
+bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
 
 /// Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 7700c722765d3..aadecfb9e8180 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -86,6 +86,7 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
   case Intrinsic::pow:
   case Intrinsic::fma:
   case Intrinsic::fmuladd:
+  case Intrinsic::is_fpclass:
   case Intrinsic::powi:
   case Intrinsic::canonicalize:
   case Intrinsic::fptosi_sat:
@@ -103,6 +104,7 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   case Intrinsic::abs:
   case Intrinsic::ctlz:
   case Intrinsic::cttz:
+  case Intrinsic::is_fpclass:
   case Intrinsic::powi:
     return (ScalarOpdIdx == 1);
   case Intrinsic::smul_fix:
@@ -116,15 +118,17 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 }
 
 bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                  unsigned OpdIdx) {
+                                                  int OpdIdx) {
   switch (ID) {
   case Intrinsic::fptosi_sat:
   case Intrinsic::fptoui_sat:
+    return OpdIdx == -1 || OpdIdx == 0;
+  case Intrinsic::is_fpclass:
     return OpdIdx == 0;
   case Intrinsic::powi:
-    return OpdIdx == 1;
+    return OpdIdx == -1 || OpdIdx == 1;
   default:
-    return false;
+    return OpdIdx == -1;
   }
 }
 

diff  --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 4aab88b74f106..30aff08629951 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -583,7 +583,9 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
   Scattered.resize(NumArgs);
 
   SmallVector<llvm::Type *, 3> Tys;
-  Tys.push_back(VT->getScalarType());
+  // Add return type if intrinsic is overloaded on it.
+  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+    Tys.push_back(VT->getScalarType());
 
   // Assumes that any vector type has the same number of elements as the return
   // vector type, which is true for all current intrinsics.

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index bd59924cf2706..8b6fa6883163c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10335,8 +10335,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
 
       Value *ScalarArg = nullptr;
       std::vector<Value *> OpVecs;
-      SmallVector<Type *, 2> TysForDecl =
-          {FixedVectorType::get(CI->getType(), E->Scalars.size())};
+      SmallVector<Type *, 2> TysForDecl;
+      // Add return type if intrinsic is overloaded on it.
+      if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+        TysForDecl.push_back(
+            FixedVectorType::get(CI->getType(), E->Scalars.size()));
       for (int j = 0, e = CI->arg_size(); j < e; ++j) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index cb14e33a4c588..41335fb8de3b8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -481,7 +481,14 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
   State.setDebugLocFromInst(&CI);
 
   for (unsigned Part = 0; Part < State.UF; ++Part) {
-    SmallVector<Type *, 2> TysForDecl = {CI.getType()};
+    SmallVector<Type *, 2> TysForDecl;
+    // Add return type if intrinsic is overloaded on it.
+    if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) {
+      TysForDecl.push_back(
+          State.VF.isVector()
+              ? VectorType::get(CI.getType()->getScalarType(), State.VF)
+              : CI.getType());
+    }
     SmallVector<Value *, 4> Args;
     for (const auto &I : enumerate(operands())) {
       // Some intrinsics have a scalar argument - don't replace it with a
@@ -500,9 +507,6 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
     Function *VectorF;
     if (VectorIntrinsicID != Intrinsic::not_intrinsic) {
       // Use vector version of the intrinsic.
-      if (State.VF.isVector())
-        TysForDecl[0] =
-            VectorType::get(CI.getType()->getScalarType(), State.VF);
       Module *M = State.Builder.GetInsertBlock()->getModule();
       VectorF = Intrinsic::getDeclaration(M, VectorIntrinsicID, TysForDecl);
       assert(VectorF && "Can't retrieve vector intrinsic.");

diff  --git a/llvm/test/Transforms/InstSimplify/is_fpclass.ll b/llvm/test/Transforms/InstSimplify/is_fpclass.ll
new file mode 100644
index 0000000000000..b14bfcbbfaac3
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/is_fpclass.ll
@@ -0,0 +1,12 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -S -passes=instsimplify | FileCheck %s
+
+define <2 x i1> @f() {
+; CHECK-LABEL: define <2 x i1> @f() {
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+;
+  %i = call <2 x i1> @llvm.is.fpclass.v2f16(<2 x half> <half 0xH7C00, half 0xH7C00>, i32 3)
+  ret <2 x i1> %i
+}
+
+declare <2 x i1> @llvm.is.fpclass.v2f16(<2 x half>, i32 immarg)

diff  --git a/llvm/test/Transforms/LoopVectorize/is_fpclass.ll b/llvm/test/Transforms/LoopVectorize/is_fpclass.ll
index 955b1b862f149..c711c730ad2ac 100644
--- a/llvm/test/Transforms/LoopVectorize/is_fpclass.ll
+++ b/llvm/test/Transforms/LoopVectorize/is_fpclass.ll
@@ -4,9 +4,28 @@
 define void @d() {
 ; CHECK-LABEL: define void @d() {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr @d, i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> zeroinitializer, i32 0)
+; CHECK-NEXT:    [[TMP3:%.*]] = select <2 x i1> [[TMP2]], <2 x float> zeroinitializer, <2 x float> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr float, ptr [[TMP1]], i32 0
+; CHECK-NEXT:    store <2 x float> [[TMP3]], ptr [[TMP4]], align 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, 0
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ 0, [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[I7:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[I7:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[I3:%.*]] = load float, ptr null, align 4
 ; CHECK-NEXT:    [[I4:%.*]] = getelementptr float, ptr @d, i64 [[I]]
 ; CHECK-NEXT:    [[I5:%.*]] = tail call i1 @llvm.is.fpclass.f32(float 0.000000e+00, i32 0)
@@ -14,7 +33,7 @@ define void @d() {
 ; CHECK-NEXT:    store float [[I6]], ptr [[I4]], align 4
 ; CHECK-NEXT:    [[I7]] = add i64 [[I]], 1
 ; CHECK-NEXT:    [[I8:%.*]] = icmp eq i64 [[I7]], 0
-; CHECK-NEXT:    br i1 [[I8]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK-NEXT:    br i1 [[I8]], label [[EXIT]], label [[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
 ;

diff  --git a/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll b/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll
index 2bb5c347b8152..d76f9f5164f3e 100644
--- a/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll
+++ b/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll
@@ -4,13 +4,8 @@
 define <2 x i1> @scalarize_is_fpclass(<2 x float> %x) {
 ; CHECK-LABEL: define <2 x i1> @scalarize_is_fpclass
 ; CHECK-SAME: (<2 x float> [[X:%.*]]) {
-; CHECK-NEXT:    [[X_I0:%.*]] = extractelement <2 x float> [[X]], i32 0
-; CHECK-NEXT:    [[ISFPCLASS_I0:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I0]], i32 123)
-; CHECK-NEXT:    [[X_I1:%.*]] = extractelement <2 x float> [[X]], i32 1
-; CHECK-NEXT:    [[ISFPCLASS_I1:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I1]], i32 123)
-; CHECK-NEXT:    [[ISFPCLASS_UPTO0:%.*]] = insertelement <2 x i1> poison, i1 [[ISFPCLASS_I0]], i32 0
-; CHECK-NEXT:    [[ISFPCLASS:%.*]] = insertelement <2 x i1> [[ISFPCLASS_UPTO0]], i1 [[ISFPCLASS_I1]], i32 1
-; CHECK-NEXT:    ret <2 x i1> [[ISFPCLASS]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> [[X]], i32 123)
+; CHECK-NEXT:    ret <2 x i1> [[TMP1]]
 ;
   %x.i0 = extractelement <2 x float> %x, i32 0
   %isfpclass.i0 = call i1 @llvm.is.fpclass.f32(float %x.i0, i32 123)

diff  --git a/llvm/test/Transforms/Scalarizer/intrinsics.ll b/llvm/test/Transforms/Scalarizer/intrinsics.ll
index 3e1c03c8a382b..69193352ea641 100644
--- a/llvm/test/Transforms/Scalarizer/intrinsics.ll
+++ b/llvm/test/Transforms/Scalarizer/intrinsics.ll
@@ -212,7 +212,12 @@ define <2 x i32> @scalarize_fptoui_sat(<2 x float> %x) #0 {
 
 define <2 x i1> @scalarize_is_fpclass(<2 x float> %x) #0 {
 ; CHECK-LABEL: @scalarize_is_fpclass(
-; CHECK-NEXT:    [[ISFPCLASS:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> [[X:%.*]], i32 123)
+; CHECK-NEXT:    [[X_I0:%.*]] = extractelement <2 x float> [[X:%.*]], i32 0
+; CHECK-NEXT:    [[ISFPCLASS_I0:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I0]], i32 123)
+; CHECK-NEXT:    [[X_I1:%.*]] = extractelement <2 x float> [[X]], i32 1
+; CHECK-NEXT:    [[ISFPCLASS_I1:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I1]], i32 123)
+; CHECK-NEXT:    [[ISFPCLASS_UPTO0:%.*]] = insertelement <2 x i1> poison, i1 [[ISFPCLASS_I0]], i32 0
+; CHECK-NEXT:    [[ISFPCLASS:%.*]] = insertelement <2 x i1> [[ISFPCLASS_UPTO0]], i1 [[ISFPCLASS_I1]], i32 1
 ; CHECK-NEXT:    ret <2 x i1> [[ISFPCLASS]]
 ;
   %isfpclass = call <2 x i1> @llvm.is.fpclass(<2 x float> %x, i32 123)


        


More information about the llvm-commits mailing list