[LLVMdev] Idea for optimization (test for remainder)

Jasper Neumann jn at sirrida.de
Sat Mar 8 11:51:18 PST 2014

```Hello Benjamin, hello folks!

[from the LLVM developer forum, comment corrected]
>> Consider the expression (x % d) == c where d and c are constants.
>> For simplicity let us assume that x is unsigned and 0 <= c < d.
>> Let us further assume that d = a * (1 << b) and a is odd.
>> Then our expression can be transformed to
>> rotate_right(x-c, b) * inverse_mul(a) <= (high_value(x) - c) / d .
>> Example [(x % 250) == 3]:
>>   sub eax,3
>>   ror eax,1
>>   imul eax,eax,0x26e978d5  // multiplicative inverse of 125
>>   cmp eax,17179869  // (0xffffffff-3) / 250
>>   jbe OK
>> [...]

> Yep, this is a long-standing issue in the peephole optimizer.
> It's not easily fixed because

> 1. We don't want to do it early (i.e. before codegen) because
>    the resulting expression is harder to analyze wrt. value range.
> 2. We can't do it late (in DAGCombiner) because it works top-down
>    and has already expanded the operation into the code you posted
>    above by the time it sees the compare.

Well, I tried a solution using the instruction combiner, and it turned
out well. I attach a working patch for unsigned values. The signed
version will come later if this patch is accepted.

How could I detect and include an additional range check which is
possible with the same amount of generated code?

By the way: Is there something like a floored or Euclidian
remainder/modulo operation (see
http://en.wikipedia.org/wiki/Modulo_operation)? How is it realized?

Best regards
Jasper
-------------- next part --------------
Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp	(revision 203284)
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp	(working copy)
@@ -795,6 +795,86 @@
return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C));
}

+/// FoldICmpRemCst - Fold "icmp pred, ([su]rem X, RemRHS), CmpRHS" where RemRHS
+/// and CmpRHS are both known to be integer constants.
+/// Change a%d==c to rotr(a-c,b)*r<=u where m<<b==d, r*m==1, u==(high(a)-c)/d.
+Instruction *InstCombiner::FoldICmpRemCst(ICmpInst &ICI, BinaryOperator *RemI,
+                                          ConstantInt *RemRHS) {
+  if (!ICI.isEquality())  // We can only treat == and != operators.
+    return 0;
+
+  bool RemIsSigned = (RemI->getOpcode() == Instruction::SRem);
+
+  if (RemIsSigned)
+    return 0;  // TODO: Also catch signed variants (currently only unsigned).
+
+  if (RemIsSigned != ICI.isSigned())
+    return 0;  // Only treat same signed variants.
+  if (RemRHS->isZero())
+    return 0;  // We can not do anything useful with remainder by zero.
+  if (RemRHS->getValue().isPowerOf2())
+    return 0;  // This is handled as and/shift elsewhere.
+
+  ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1));
+  const APInt &CmpRHSV = CmpRHS->getValue();
+
+  APInt Factor = RemRHS->getValue();
+  if (CmpRHSV.uge(Factor))
+    return 0;  // The remainer is too large (always false).
+
+  Value *X = RemI->getOperand(0);
+  unsigned BitWidth = X->getType()->getScalarSizeInBits();
+  APInt One(BitWidth, 1);
+
+  // Split powers of two from Factor giving Rotate.
+  unsigned Rotate = 0;
+  while ((Factor & One) == 0) {
+    ++Rotate;
+    Factor = Factor.lshr(1);
+  }
+
+  // Invert Factor via multiplicative inverse.
+  APInt Mod = APInt::getSignedMinValue(BitWidth+1);
+  Factor = Factor.zext(BitWidth+1);
+  Factor = Factor.multiplicativeInverse(Mod);
+  Factor = Factor.trunc(BitWidth);
+
+  // Now build the transformed expression.
+  // rotr(X-CmpRHS, Rotate)*Factor <= HiBound
+  Value *Expr = X;
+
+  if (CmpRHSV != 0) {
+    Expr = Builder->CreateSub(Expr, CmpRHS);
+  }
+
+  if (Rotate) {
+    // Emulate missing rotate operation.
+    // Expr = Builder->CreateRotr(Expr, Rotate);
+    Value *Expr1 = Builder->CreateLShr(Expr, Rotate);
+    Value *Expr2 = Builder->CreateShl(Expr, BitWidth-Rotate);
+    Expr = Builder->CreateOr(Expr1, Expr2);
+  }
+
+  if (Factor != 1) {
+    Expr = Builder->CreateMul(Expr, ConstantInt::get(X->getType(), Factor));
+  }
+
+  APInt OpMax = APInt::getMaxValue(BitWidth);
+  Constant *HiBound = ConstantExpr::getUDiv(
+    ConstantInt::get(X->getType(), OpMax-CmpRHSV),
+    RemRHS);
+
+  // Get the ICmp opcode
+  ICmpInst::Predicate Pred = ICI.getPredicate();
+  switch (Pred) {
+  default: llvm_unreachable("Unhandled icmp opcode!");
+  case ICmpInst::ICMP_EQ:
+    return new ICmpInst(ICmpInst::ICMP_ULE, Expr, HiBound);
+  case ICmpInst::ICMP_NE:
+    return new ICmpInst(ICmpInst::ICMP_UGT, Expr, HiBound);
+  }
+}
+
/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS
/// and CmpRHS are both known to be integer constants.
Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
@@ -1531,6 +1611,20 @@
break;
}

+  case Instruction::SRem:
+  case Instruction::URem:
+    // Fold: icmp pred ([us]rem X, C1), C2 -> range test
+    // Fold this rem into the comparison, producing a range check.
+    // Determine, based on the remainder type, what the range is being
+    // checked.  If there is an overflow on the low or high side, remember
+    // it, otherwise compute the range [low, hi) bounding the new value.
+    // See: InsertRangeTest above for the kinds of replacements possible.
+    if (ConstantInt *RemRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1)))
+      if (Instruction *R = FoldICmpRemCst(ICI, cast<BinaryOperator>(LHSI),
+                                          RemRHS))
+        return R;
+    break;
+
case Instruction::SDiv:
case Instruction::UDiv:
// Fold: icmp pred ([us]div X, C1), C2 -> range test
Index: lib/Transforms/InstCombine/InstCombine.h
===================================================================
--- lib/Transforms/InstCombine/InstCombine.h	(revision 203284)
+++ lib/Transforms/InstCombine/InstCombine.h	(working copy)
@@ -163,6 +163,8 @@
Instruction *visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
Instruction *LHS,
ConstantInt *RHS);
+  Instruction *FoldICmpRemCst(ICmpInst &ICI, BinaryOperator *RemI,
+                              ConstantInt *RemRHS);
Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
ConstantInt *DivRHS);
Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI,
```