[PATCH] D37888: [SCEV] Generalize folding of trunc(x)+n*trunc(y) into folding m*trunc(x)+n*trunc(y)

Daniel Neilson via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 09:28:03 PDT 2017


dneilson updated this revision to Diff 116013.
dneilson added a comment.

Minor rearrangement of the patch to get SrcType from the lambda instead of from a trunc.
(Sorry for the delay; I was out of town and away from computers.)


https://reviews.llvm.org/D37888

Files:
  lib/Analysis/ScalarEvolution.cpp
  unittests/Analysis/ScalarEvolutionTest.cpp


Index: unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- unittests/Analysis/ScalarEvolutionTest.cpp
+++ unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1009,5 +1009,37 @@
   auto Result = SE.createAddRecFromPHIWithCasts(cast<SCEVUnknown>(Expr));
 }
 
+TEST_F(ScalarEvolutionsTest, SCEVFoldSumOfTruncs) {
+  // Verify that the following SCEV gets folded to a zero:
+  //  (-1 * (trunc i64 (-1 * %0) to i32)) + (-1 * (trunc i64 %0 to i32))
+  Type *ArgTy = Type::getInt64Ty(Context);
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  SmallVector<Type *, 1> Types;
+  Types.push_back(ArgTy);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
+  BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
+  ReturnInst::Create(Context, nullptr, BB);
+
+  ScalarEvolution SE = buildSE(*F);
+
+  auto *Arg = &*(F->arg_begin());
+  const auto *ArgSCEV = SE.getSCEV(Arg);
+
+  // Build A = -1 * trunc(-1 * %0) and B = -1 * trunc(%0). In i32 arithmetic
+  // A equals trunc(%0), so A + B should fold to the constant 0.
+  const auto *A0 = SE.getNegativeSCEV(ArgSCEV);
+  const auto *A1 = SE.getTruncateExpr(A0, Int32Ty);
+  const auto *A = SE.getNegativeSCEV(A1);
+
+  const auto *B0 = SE.getTruncateExpr(ArgSCEV, Int32Ty);
+  const auto *B = SE.getNegativeSCEV(B0);
+
+  const auto *Expr = SE.getAddExpr(A, B);
+  // Verify that the SCEV was folded to 0
+  const auto *ZeroConst = SE.getConstant(Int32Ty, 0);
+  EXPECT_EQ(Expr, ZeroConst);
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -2338,12 +2338,34 @@
 
   // Check for truncates. If all the operands are truncated from the same
   // type, see if factoring out the truncate would permit the result to be
-  // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
+  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
   // if the contents of the resulting outer trunc fold to something simple.
-  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
-    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
-    Type *DstType = Trunc->getType();
-    Type *SrcType = Trunc->getOperand()->getType();
+  auto FindTruncSrcType = [&]() -> Type * {
+    // Go through the available Ops to see if we have a compatible trunc() to
+    // start processing.
+    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+      return T->getOperand()->getType();
+    for (unsigned i = Idx; i < Ops.size() && Ops[i]->getSCEVType() <= scMulExpr;
+         ++i) {
+      if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[i])) {
+        bool Ok = true;
+        for (unsigned j = 0, e = Mul->getNumOperands(); Ok && j < e; ++j) {
+          const auto *Op = Mul->getOperand(j);
+          if (const auto *T = dyn_cast<SCEVTruncateExpr>(Op)) {
+            return T->getOperand()->getType();
+          } else if (!isa<SCEVConstant>(Op)) {
+            Ok = false;
+          }
+        }
+        if (!Ok)
+          break;
+      } else {
+        break;
+      }
+    }
+    return nullptr;
+  };
+  if (auto *SrcType = FindTruncSrcType()) {
     SmallVector<const SCEV *, 8> LargeOps;
     bool Ok = true;
     // Check all the operands to see if they can be represented in the
@@ -2386,7 +2408,7 @@
       const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
       // If it folds to something simple, use it. Otherwise, don't.
       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
-        return getTruncateExpr(Fold, DstType);
+        return getTruncateExpr(Fold, Ty);
     }
   }
 


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D37888.116013.patch
Type: text/x-patch
Size: 3778 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20170920/c1f66a6b/attachment.bin>


More information about the llvm-commits mailing list