[llvm-commits] [llvm] r171812 - in /llvm/trunk: lib/Transforms/Vectorize/LoopVectorize.cpp test/Transforms/LoopVectorize/float-reduction.ll

Nadav Rotem nrotem at apple.com
Mon Jan 7 15:13:01 PST 2013


Author: nadav
Date: Mon Jan  7 17:13:00 2013
New Revision: 171812

URL: http://llvm.org/viewvc/llvm-project?rev=171812&view=rev
Log:
LoopVectorizer: Add support for floating point reductions

Added:
    llvm/trunk/test/Transforms/LoopVectorize/float-reduction.ll
Modified:
    llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp

Modified: llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp?rev=171812&r1=171811&r2=171812&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp Mon Jan  7 17:13:00 2013
@@ -215,10 +215,6 @@
   /// broadcast them into a vector.
   VectorParts &getVectorValue(Value *V);
 
-  /// Get a uniform vector of constant integers. We use this to get
-  /// vectors of ones and zeros for the reduction code.
-  Constant* getUniformVector(unsigned Val, Type* ScalarTy);
-
   /// Generate a shuffle sequence that will reverse the vector Vec.
   Value *reverseVector(Value *Vec);
 
@@ -325,12 +321,14 @@
 
   /// This enum represents the kinds of reductions that we support.
   enum ReductionKind {
-    NoReduction, ///< Not a reduction.
-    IntegerAdd,  ///< Sum of numbers.
-    IntegerMult, ///< Product of numbers.
-    IntegerOr,   ///< Bitwise or logical OR of numbers.
-    IntegerAnd,  ///< Bitwise or logical AND of numbers.
-    IntegerXor   ///< Bitwise or logical XOR of numbers.
+    RK_NoReduction, ///< Not a reduction.
+    RK_IntegerAdd,  ///< Sum of integers.
+    RK_IntegerMult, ///< Product of integers.
+    RK_IntegerOr,   ///< Bitwise or logical OR of numbers.
+    RK_IntegerAnd,  ///< Bitwise or logical AND of numbers.
+    RK_IntegerXor,  ///< Bitwise or logical XOR of numbers.
+    RK_FloatAdd,    ///< Sum of floats.
+    RK_FloatMult    ///< Product of floats.
   };
 
   /// This enum represents the kinds of inductions that we support.
@@ -343,8 +341,8 @@
 
   /// This POD struct holds information about reduction variables.
   struct ReductionDescriptor {
-    ReductionDescriptor() : StartValue(0), LoopExitInstr(0), Kind(NoReduction) {
-    }
+    ReductionDescriptor() : StartValue(0), LoopExitInstr(0),
+      Kind(RK_NoReduction) {}
 
     ReductionDescriptor(Value *Start, Instruction *Exit, ReductionKind K)
         : StartValue(Start), LoopExitInstr(Exit), Kind(K) {}
@@ -790,11 +788,6 @@
   return WidenMap.get(V);
 }
 
-Constant*
-InnerLoopVectorizer::getUniformVector(unsigned Val, Type* ScalarTy) {
-  return ConstantVector::getSplat(VF, ConstantInt::get(ScalarTy, Val, true));
-}
-
 Value *InnerLoopVectorizer::reverseVector(Value *Vec) {
   assert(Vec->getType()->isVectorTy() && "Invalid type");
   SmallVector<Constant*, 8> ShuffleMask;
@@ -1215,20 +1208,26 @@
 
 /// This function returns the identity element (or neutral element) for
 /// the operation K.
-static unsigned
-getReductionIdentity(LoopVectorizationLegality::ReductionKind K) {
+static Constant*
+getReductionIdentity(LoopVectorizationLegality::ReductionKind K, Type *Tp) {
   switch (K) {
-  case LoopVectorizationLegality::IntegerXor:
-  case LoopVectorizationLegality::IntegerAdd:
-  case LoopVectorizationLegality::IntegerOr:
+  case LoopVectorizationLegality::RK_IntegerXor:
+  case LoopVectorizationLegality::RK_IntegerAdd:
+  case LoopVectorizationLegality::RK_IntegerOr:
     // Adding, Xoring, Oring zero to a number does not change it.
-    return 0;
-  case LoopVectorizationLegality::IntegerMult:
+    return ConstantInt::get(Tp, 0);
+  case LoopVectorizationLegality::RK_IntegerMult:
     // Multiplying a number by 1 does not change it.
-    return 1;
-  case LoopVectorizationLegality::IntegerAnd:
+    return ConstantInt::get(Tp, 1);
+  case LoopVectorizationLegality::RK_IntegerAnd:
     // AND-ing a number with an all-1 value does not change it.
-    return -1;
+    return ConstantInt::get(Tp, -1, true);
+  case LoopVectorizationLegality::RK_FloatMult:
+    // Multiplying a number by 1 does not change it.
+    return ConstantFP::get(Tp, 1.0L);
+  case LoopVectorizationLegality::RK_FloatAdd:
+    // Adding zero to a number does not change it.
+    return ConstantFP::get(Tp, 0.0L);
   default:
     llvm_unreachable("Unknown reduction kind");
   }
@@ -1329,8 +1328,8 @@
 
     // Find the reduction identity variable. Zero for addition, or, xor,
     // one for multiplication, -1 for And.
-    Constant *Identity = getUniformVector(getReductionIdentity(RdxDesc.Kind),
-                                          VecTy->getScalarType());
+    Constant *Iden = getReductionIdentity(RdxDesc.Kind, VecTy->getScalarType());
+    Constant *Identity = ConstantVector::getSplat(VF, Iden);
 
     // This vector is the Identity vector where the first element is the
     // incoming scalar reduction.
@@ -1378,26 +1377,34 @@
     Value *ReducedPartRdx = RdxParts[0];
     for (unsigned part = 1; part < UF; ++part) {
       switch (RdxDesc.Kind) {
-      case LoopVectorizationLegality::IntegerAdd:
+      case LoopVectorizationLegality::RK_IntegerAdd:
         ReducedPartRdx = 
           Builder.CreateAdd(RdxParts[part], ReducedPartRdx, "add.rdx");
         break;
-      case LoopVectorizationLegality::IntegerMult:
+      case LoopVectorizationLegality::RK_IntegerMult:
         ReducedPartRdx =
           Builder.CreateMul(RdxParts[part], ReducedPartRdx, "mul.rdx");
         break;
-      case LoopVectorizationLegality::IntegerOr:
+      case LoopVectorizationLegality::RK_IntegerOr:
         ReducedPartRdx =
           Builder.CreateOr(RdxParts[part], ReducedPartRdx, "or.rdx");
         break;
-      case LoopVectorizationLegality::IntegerAnd:
+      case LoopVectorizationLegality::RK_IntegerAnd:
         ReducedPartRdx =
           Builder.CreateAnd(RdxParts[part], ReducedPartRdx, "and.rdx");
         break;
-      case LoopVectorizationLegality::IntegerXor:
+      case LoopVectorizationLegality::RK_IntegerXor:
         ReducedPartRdx =
           Builder.CreateXor(RdxParts[part], ReducedPartRdx, "xor.rdx");
         break;
+      case LoopVectorizationLegality::RK_FloatMult:
+        ReducedPartRdx =
+          Builder.CreateFMul(RdxParts[part], ReducedPartRdx, "fmul.rdx");
+        break;
+      case LoopVectorizationLegality::RK_FloatAdd:
+        ReducedPartRdx =
+          Builder.CreateFAdd(RdxParts[part], ReducedPartRdx, "fadd.rdx");
+        break;
       default:
         llvm_unreachable("Unknown reduction operation");
       }
@@ -1428,21 +1435,27 @@
 
       // Emit the operation on the shuffled value.
       switch (RdxDesc.Kind) {
-      case LoopVectorizationLegality::IntegerAdd:
+      case LoopVectorizationLegality::RK_IntegerAdd:
         TmpVec = Builder.CreateAdd(TmpVec, Shuf, "add.rdx");
         break;
-      case LoopVectorizationLegality::IntegerMult:
+      case LoopVectorizationLegality::RK_IntegerMult:
         TmpVec = Builder.CreateMul(TmpVec, Shuf, "mul.rdx");
         break;
-      case LoopVectorizationLegality::IntegerOr:
+      case LoopVectorizationLegality::RK_IntegerOr:
         TmpVec = Builder.CreateOr(TmpVec, Shuf, "or.rdx");
         break;
-      case LoopVectorizationLegality::IntegerAnd:
+      case LoopVectorizationLegality::RK_IntegerAnd:
         TmpVec = Builder.CreateAnd(TmpVec, Shuf, "and.rdx");
         break;
-      case LoopVectorizationLegality::IntegerXor:
+      case LoopVectorizationLegality::RK_IntegerXor:
         TmpVec = Builder.CreateXor(TmpVec, Shuf, "xor.rdx");
         break;
+      case LoopVectorizationLegality::RK_FloatMult:
+        TmpVec = Builder.CreateFMul(TmpVec, Shuf, "fmul.rdx");
+        break;
+      case LoopVectorizationLegality::RK_FloatAdd:
+        TmpVec = Builder.CreateFAdd(TmpVec, Shuf, "fadd.rdx");
+        break;
       default:
         llvm_unreachable("Unknown reduction operation");
       }
@@ -2074,6 +2087,7 @@
 
         // Check that this PHI type is allowed.
         if (!Phi->getType()->isIntegerTy() &&
+            !Phi->getType()->isFloatingPointTy() &&
             !Phi->getType()->isPointerTy()) {
-          DEBUG(dbgs() << "LV: Found an non-int non-pointer PHI.\n");
+          DEBUG(dbgs() << "LV: Found a non-int non-fp non-pointer PHI.\n");
           return false;
@@ -2105,26 +2119,34 @@
           continue;
         }
 
-        if (AddReductionVar(Phi, IntegerAdd)) {
+        if (AddReductionVar(Phi, RK_IntegerAdd)) {
           DEBUG(dbgs() << "LV: Found an ADD reduction PHI."<< *Phi <<"\n");
           continue;
         }
-        if (AddReductionVar(Phi, IntegerMult)) {
+        if (AddReductionVar(Phi, RK_IntegerMult)) {
           DEBUG(dbgs() << "LV: Found a MUL reduction PHI."<< *Phi <<"\n");
           continue;
         }
-        if (AddReductionVar(Phi, IntegerOr)) {
+        if (AddReductionVar(Phi, RK_IntegerOr)) {
           DEBUG(dbgs() << "LV: Found an OR reduction PHI."<< *Phi <<"\n");
           continue;
         }
-        if (AddReductionVar(Phi, IntegerAnd)) {
+        if (AddReductionVar(Phi, RK_IntegerAnd)) {
           DEBUG(dbgs() << "LV: Found an AND reduction PHI."<< *Phi <<"\n");
           continue;
         }
-        if (AddReductionVar(Phi, IntegerXor)) {
+        if (AddReductionVar(Phi, RK_IntegerXor)) {
           DEBUG(dbgs() << "LV: Found a XOR reduction PHI."<< *Phi <<"\n");
           continue;
         }
+        if (AddReductionVar(Phi, RK_FloatMult)) {
+          DEBUG(dbgs() << "LV: Found an FMult reduction PHI."<< *Phi <<"\n");
+          continue;
+        }
+        if (AddReductionVar(Phi, RK_FloatAdd)) {
+          DEBUG(dbgs() << "LV: Found an FAdd reduction PHI."<< *Phi <<"\n");
+          continue;
+        }
 
         DEBUG(dbgs() << "LV: Found an unidentified PHI."<< *Phi <<"\n");
         return false;
@@ -2419,6 +2441,8 @@
   // This includes users of the reduction, variables (which form a cycle
   // which ends in the phi node).
   Instruction *ExitInstruction = 0;
+  // Indicates that we found a binary operation in our scan.
+  bool FoundBinOp = false;
 
   // Iter is our iterator. We start with the PHI node and scan for all of the
   // users of this instruction. All users must be instructions that can be
@@ -2436,6 +2460,9 @@
     // Did we reach the initial PHI node already ?
     bool FoundStartPHI = false;
 
+    // Is this a bin op ?
+    FoundBinOp |= !isa<PHINode>(Iter);
+
     // For each of the *users* of iter.
     for (Value::use_iterator it = Iter->use_begin(), e = Iter->use_end();
          it != e; ++it) {
@@ -2475,7 +2502,7 @@
 
       // Reductions of instructions such as Div, and Sub is only
       // possible if the LHS is the reduction variable.
-      if (!U->isCommutative() && U->getOperand(0) != Iter)
+      if (!U->isCommutative() && !isa<PHINode>(U) && U->getOperand(0) != Iter)
         return false;
 
       Iter = U;
@@ -2484,46 +2511,55 @@
     // We found a reduction var if we have reached the original
     // phi node and we only have a single instruction with out-of-loop
     // users.
-    if (FoundStartPHI && ExitInstruction) {
+    if (FoundStartPHI) {
+      // We've ended the cycle. This is a reduction variable only if we have
+      // an outside user and the cycle contains a binary op. Bail out before
+      // recording anything so that a failed match leaves no state behind.
+      if (!FoundBinOp || !ExitInstruction)
+        return false;
+
       // This instruction is allowed to have out-of-loop users.
       AllowedExit.insert(ExitInstruction);
 
       // Save the description of this reduction variable.
       ReductionDescriptor RD(RdxStart, ExitInstruction, Kind);
       Reductions[Phi] = RD;
       return true;
     }
-
-    // If we've reached the start PHI but did not find an outside user then
-    // this is dead code. Abort.
-    if (FoundStartPHI)
-      return false;
   }
 }
 
 bool
 LoopVectorizationLegality::isReductionInstr(Instruction *I,
                                             ReductionKind Kind) {
+  bool FP = I->getType()->isFloatingPointTy();
+  // FP ops are only reducible when reassociation is allowed (fast-math).
+  bool FastMath = (FP && I->isCommutative() && I->isAssociative());
   switch (I->getOpcode()) {
   default:
     return false;
   case Instruction::PHI:
+    if (FP && (Kind != RK_FloatMult && Kind != RK_FloatAdd))
+      return false;
     // possibly.
     return true;
   case Instruction::Sub:
   case Instruction::Add:
-    return Kind == IntegerAdd;
+    return Kind == RK_IntegerAdd;
   case Instruction::SDiv:
   case Instruction::UDiv:
   case Instruction::Mul:
-    return Kind == IntegerMult;
+    return Kind == RK_IntegerMult;
   case Instruction::And:
-    return Kind == IntegerAnd;
+    return Kind == RK_IntegerAnd;
   case Instruction::Or:
-    return Kind == IntegerOr;
+    return Kind == RK_IntegerOr;
   case Instruction::Xor:
-    return Kind == IntegerXor;
+    return Kind == RK_IntegerXor;
+  case Instruction::FMul:
+    return Kind == RK_FloatMult && FastMath;
+  case Instruction::FAdd:
+    return Kind == RK_FloatAdd && FastMath;
   }
 }
 
 LoopVectorizationLegality::InductionKind

Added: llvm/trunk/test/Transforms/LoopVectorize/float-reduction.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopVectorize/float-reduction.ll?rev=171812&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/LoopVectorize/float-reduction.ll (added)
+++ llvm/trunk/test/Transforms/LoopVectorize/float-reduction.ll Mon Jan  7 17:13:00 2013
@@ -0,0 +1,53 @@
+; RUN: opt < %s  -loop-vectorize -force-vector-unroll=1 -force-vector-width=4 -dce -instcombine -S | FileCheck %s
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-apple-macosx10.8.0"
+;CHECK: @foo
+;CHECK: fadd <4 x float>
+;CHECK: ret
+define float @foo(float* nocapture %A, i32* nocapture %n) nounwind uwtable readonly ssp {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %sum.04 = phi float [ 0.000000e+00, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds float* %A, i64 %indvars.iv
+  %0 = load float* %arrayidx, align 4, !tbaa !0
+  %add = fadd fast float %sum.04, %0
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, 200
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret float %add
+}
+
+; Without fast-math flags the fadd is not reassociable, so the reduction
+; must not be vectorized.
+;CHECK: @foo_strict
+;CHECK-NOT: <4 x float>
+;CHECK: ret
+define float @foo_strict(float* nocapture %A) nounwind uwtable readonly ssp {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %sum.04 = phi float [ 0.000000e+00, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds float* %A, i64 %indvars.iv
+  %0 = load float* %arrayidx, align 4, !tbaa !0
+  %add = fadd float %sum.04, %0
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, 200
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret float %add
+}
+
+!0 = metadata !{metadata !"float", metadata !1}
+!1 = metadata !{metadata !"omnipotent char", metadata !2}
+!2 = metadata !{metadata !"Simple C/C++ TBAA"}





More information about the llvm-commits mailing list