[llvm] [AggressiveInstCombine] Match long high-half multiply (PR #168396)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 17 08:21:08 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

This patch recognizes a high-half multiply that has been expanded into its
parts and folds it into a single, wider multiply.

Considering a multiply made up of high and low parts, we can split the
multiply into:
```
 x * y == (xh*T + xl) * (yh*T + yl)
```
where `xh == x>>32` and `xl == x & 0xffffffff`. `T = 2^32`.
This expands to
```
 xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
```
which I find helpful to draw as
```
[  xh*yh  ]
     [  xh*yl  ]
     [  xl*yh  ]
          [  xl*yl  ]
```

We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
carries. The carries make this difficult and there are multiple ways of
representing them. The ones we attempt to support here are:
```
 Carry:  xh*yh + carry + lowsum>>32
         carry = lowsum < xh*yl ? 0x100000000 : 0
         lowsum = xh*yl + xl*yh + (xl*yl>>32)
 Ladder: xh*yh + c2>>32 + c3>>32
         c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
 Carry4: xh*yh + carry + crosssum>>32 + (xl*yl>>32 + crosssum&0xffffffff) >> 32
         crosssum = xh*yl + xl*yh
         carry = crosssum < xh*yl ? 0x100000000 : 0
 Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
         low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
```

They all start by matching `xh*yh` + 2 or 3 other operands. The bottom of the
tree is `xh*yh`, `xh*yl`, `xl*yh` and `xl*yl`.

Based on https://github.com/llvm/llvm-project/pull/156879 by @<!-- -->c-rhodes

---

Patch is 226.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168396.diff


5 Files Affected:

- (modified) llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+301) 
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll (+755) 
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry4.ll (+3019) 
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder.ll (+818) 
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder4.ll (+530) 


``````````diff
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index b575d76e897d2..fb71f57eaa502 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
   return false;
 }
 
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+///  x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
+/// This expands to
+///  xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [  xh*yh  ]
+///      [  xh*yl  ]
+///      [  xl*yh  ]
+///           [  xl*yl  ]
+/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
+/// some carries. The carries make this difficult and there are multiple ways
+/// of representing them. The ones we attempt to support here are:
+///  Carry:  xh*yh + carry + lowsum>>32
+///          carry = lowsum < xh*yl ? 0x100000000 : 0
+///          lowsum = xh*yl + xl*yh + (xl*yl>>32)
+///  Ladder: xh*yh + c2>>32 + c3>>32
+///          c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
+///  Carry4: xh*yh + carry + crosssum>>32 + lowaccum>>32
+///          lowaccum = (xl*yl>>32) + crosssum&0xffffffff
+///          crosssum = xh*yl + xl*yh
+///          carry = crosssum < xh*yl ? 0x100000000 : 0
+///  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+///          low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+/// The "32" and "0xffffffff" above are written for a 64-bit type; the matchers
+/// use BW/2 and the low-half mask of the BW-bit scalar type.
+///
+/// \returns true if \p I was matched, in which case all uses of \p I are
+/// replaced by trunc((zext(x) * zext(y)) >> BW). \p I itself is not erased.
+static bool foldMulHigh(Instruction &I) {
+  // Every supported pattern is rooted at an add, so reject anything else
+  // before building the mask and matchers below for every instruction.
+  Type *Ty = I.getType();
+  if (I.getOpcode() != Instruction::Add || !Ty->isIntOrIntVectorTy())
+    return false;
+
+  // The operands are split into two equal halves, so the width must be even.
+  unsigned BW = Ty->getScalarSizeInBits();
+  if (BW % 2 != 0)
+    return false;
+  APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+
+  // Replace I with trunc((zext(X) * zext(Y)) >> BW), computed in a type twice
+  // as wide as Ty. I is left in place with no remaining uses.
+  auto CreateMulHigh = [&](Value *X, Value *Y) {
+    IRBuilder<> Builder(&I);
+    Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+    Value *XExt = Builder.CreateZExt(X, NTy);
+    Value *YExt = Builder.CreateZExt(Y, NTy);
+    Value *Mul = Builder.CreateMul(XExt, YExt);
+    Value *High = Builder.CreateLShr(Mul, BW);
+    Value *Res = Builder.CreateTrunc(High, Ty);
+    I.replaceAllUsesWith(Res);
+    LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+                      << *Y << "\n");
+    return true;
+  };
+
+  // Common check routines for X_lo*Y_lo and X_hi*Y_lo (mul operands in either
+  // order).
+  auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+    return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+  auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+    return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+
+  auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+                              Instruction *B) {
+    // Looking for LowSum >> 32 and carry (select), in either order.
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, B);
+
+    // Carry = LowSum < XhYl ? 0x100000000 : 0
+    CmpPredicate Pred;
+    Value *LowSum, *XhYl;
+    if (!match(Carry,
+               m_OneUse(m_Select(
+                   m_OneUse(m_ICmp(Pred, m_Value(LowSum), m_Value(XhYl))),
+                   m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+        Pred != ICmpInst::ICMP_ULT)
+      return false;
+
+    // XhYl can be Xh*Yl or Xl*Yh; swap X and Y so that it is Xh*Yl.
+    if (!CheckHiLo(XhYl, X, Y)) {
+      if (CheckHiLo(XhYl, Y, X))
+        std::swap(X, Y);
+      else
+        return false;
+    }
+    if (XhYl->hasNUsesOrMore(3))
+      return false;
+
+    // B = LowSum >> 32
+    if (!match(B,
+               m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
+        LowSum->hasNUsesOrMore(3))
+      return false;
+
+    // LowSum = XhYl + XlYh + XlYl>>32, associated in any of three ways.
+    Value *XlYh, *XlYl;
+    auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
+    if (!match(LowSum,
+               m_c_Add(m_Specific(XhYl),
+                       m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
+        !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
+                               m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
+        !match(LowSum,
+               m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
+                                                m_OneUse(m_Value(XlYh)))))))
+      return false;
+
+    // Check XlYl and XlYh
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+    if (!CheckHiLo(XlYh, Y, X))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
+                               Instruction *B) {
+    //  xh*yh + c2>>32 + c3>>32
+    //  c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
+    Value *XlYh, *XhYl, *C2, *C3;
+    // Strip off the two expected shifts.
+    if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
+        !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
+      return false;
+
+    // Match c3 = c2&0xffffffff + xl*yh. The two shifts may appear in either
+    // order, so swap C2/C3 if the first guess does not fit. The cross product
+    // bound here may be either xh*yl or xl*yh; it is disambiguated below.
+    if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
+                           m_Value(XhYl))))
+      std::swap(C2, C3);
+    if (!match(C3,
+               m_c_Add(m_OneUse(m_And(m_Specific(C2), m_SpecificInt(LowMask))),
+                       m_Value(XhYl))) ||
+        !C3->hasOneUse() || C2->hasNUsesOrMore(3))
+      return false;
+
+    // Match c2 = xh*yl + (xl*yl >> 32)
+    Value *XlYl;
+    if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
+                           m_Value(XlYh))))
+      return false;
+
+    // Match XhYl and XlYh - they can appear either way around.
+    if (!CheckHiLo(XlYh, Y, X))
+      std::swap(XlYh, XhYl);
+    if (!CheckHiLo(XlYh, Y, X))
+      return false;
+    if (!CheckHiLo(XhYl, X, Y))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
+                                Instruction *B, Instruction *C) {
+    //  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+    //           low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+
+    // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
+    auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2));
+    if (!match(A, ShiftAdd))
+      std::swap(A, B);
+    if (!match(A, ShiftAdd))
+      std::swap(A, C);
+    Value *Low;
+    if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2))))
+      return false;
+
+    // Match B == XhYl>>32 and C == XlYh>>32
+    Value *XhYl, *XlYh;
+    if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) ||
+        !match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2))))
+      return false;
+    if (!CheckHiLo(XhYl, X, Y))
+      std::swap(XhYl, XlYh);
+    if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
+      return false;
+    if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
+      return false;
+
+    // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff, associated
+    // in any of three ways.
+    Value *XlYl;
+    if (!match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
+                m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) &&
+        !match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+                    m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+                m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
+        !match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
+                    m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+                m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
+                               Instruction *B, Instruction *C) {
+    //  xh*yh + carry + crosssum>>32 + (xl*yl>>32 + crosssum&0xffffffff)>>32
+    //  crosssum = xh*yl+xl*yh
+    //  carry = crosssum < xh*yl ? 0x100000000 : 0
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, B);
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, C);
+
+    // Carry = CrossSum < XhYl ? 0x100000000 : 0
+    CmpPredicate Pred;
+    Value *CrossSum, *XhYl;
+    if (!match(Carry,
+               m_OneUse(m_Select(
+                   m_OneUse(m_ICmp(Pred, m_Value(CrossSum), m_Value(XhYl))),
+                   m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+        Pred != ICmpInst::ICMP_ULT)
+      return false;
+
+    // B = CrossSum >> 32 (it may have been passed in as either B or C).
+    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+      std::swap(B, C);
+    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+      return false;
+
+    // C = LowAccum >> 32, LowAccum = (XlYl >> 32) + (CrossSum & 0xffffffff)
+    Value *XlYl, *LowAccum;
+    if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) ||
+        !match(LowAccum,
+               m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
+                       m_OneUse(m_And(m_Specific(CrossSum),
+                                      m_SpecificInt(LowMask))))) ||
+        LowAccum->hasNUsesOrMore(3))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    // XhYl can be Xh*Yl or Xl*Yh; swap X and Y so that it is Xh*Yl, then
+    // require CrossSum = XhYl + Yh*Xl.
+    if (!CheckHiLo(XhYl, X, Y))
+      std::swap(X, Y);
+    if (!CheckHiLo(XhYl, X, Y))
+      return false;
+    if (!match(CrossSum,
+               m_c_Add(m_Specific(XhYl),
+                       m_OneUse(m_c_Mul(
+                           m_LShr(m_Specific(Y), m_SpecificInt(BW / 2)),
+                           m_And(m_Specific(X), m_SpecificInt(LowMask)))))) ||
+        CrossSum->hasNUsesOrMore(4) || XhYl->hasNUsesOrMore(3))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  // X and Y are the two inputs, A, B and C are other parts of the pattern
+  // (crosssum>>32, carry, etc).
+  Value *X, *Y;
+  Instruction *A, *B, *C;
+  auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)),
+                             m_LShr(m_Value(Y), m_SpecificInt(BW / 2))));
+  // xh*yh plus two other terms: the Carry and Ladder forms.
+  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
+                                              m_Instruction(B))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
+      A->hasOneUse() && B->hasOneUse())
+    if (foldMulHighCarry(X, Y, A, B) || foldMulHighLadder(X, Y, A, B))
+      return true;
+
+  // xh*yh plus three other terms: the Carry4 and Ladder4 forms.
+  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
+                                   m_Instruction(A),
+                                   m_OneUse(m_Add(m_Instruction(B),
+                                                  m_Instruction(C))))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(
+                             HiHi, m_OneUse(m_Add(m_Instruction(B),
+                                                  m_Instruction(C))))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(
+                             m_Instruction(B),
+                             m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
+       match(&I,
+             m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
+                     m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
+      A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
+    return foldMulHighCarry4(X, Y, A, B, C) ||
+           foldMulHighLadder4(X, Y, A, B, C);
+
+  return false;
+}
+
 /// This is the entry point for folds that could be implemented in regular
 /// InstCombine, but they are separated because they are not expected to
 /// occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1795,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
       MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
       MadeChange |= foldPatternedLoads(I, DL);
       MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
+      MadeChange |= foldMulHigh(I);
       // NOTE: This function introduces erasing of the instruction `I`, so it
       // needs to be called at the end of this sequence, otherwise we may make
       // bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
new file mode 100644
index 0000000000000..b78095cac0df9
--- /dev/null
+++ b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
@@ -0,0 +1,755 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=aggressive-instcombine,instcombine -S | FileCheck %s
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+; The i32 operands are split into 16-bit halves. The carry out of the low sum
+; %add6 is detected by comparing it against %mul (xh*yl) and materialised as a
+; select of 65536; the expected result is a single i64 multiply.
+define i32 @mul_carry(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %shr, %and2
+  %mul3 = mul nuw i32 %and, %shr1
+  %add = add i32 %mul, %mul3
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %add, %shr5
+  %cmp = icmp ult i32 %add6, %mul
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %add9 = add nuw i32 %mul8, %cond
+  %shr10 = lshr i32 %add6, 16
+  %add11 = add i32 %add9, %shr10
+  ret i32 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+; Same pattern as @mul_carry at i128 with 64-bit halves; it is expected to fold
+; to a single i256 multiply.
+define i128 @mul_carry_i128(i128 %x, i128 %y) {
+; CHECK-LABEL: define i128 @mul_carry_i128(
+; CHECK-SAME: i128 [[X:%.*]], i128 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i128 [[X]] to i256
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i128 [[Y]] to i256
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i256 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i256 [[TMP2]], 128
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i256 [[TMP3]] to i128
+; CHECK-NEXT:    ret i128 [[ADD11]]
+;
+entry:
+  %shr = lshr i128 %x, 64
+  %and = and i128 %x, u0xffffffffffffffff
+  %shr1 = lshr i128 %y, 64
+  %and2 = and i128 %y, u0xffffffffffffffff
+  %mul = mul nuw i128 %shr, %and2
+  %mul3 = mul nuw i128 %and, %shr1
+  %add = add i128 %mul, %mul3
+  %mul4 = mul nuw i128 %and, %and2
+  %shr5 = lshr i128 %mul4, 64
+  %add6 = add i128 %add, %shr5
+  %cmp = icmp ult i128 %add6, %mul
+  %cond = select i1 %cmp, i128 u0x10000000000000000, i128 0
+  %mul8 = mul nuw i128 %shr, %shr1
+  %add9 = add nuw i128 %mul8, %cond
+  %shr10 = lshr i128 %add6, 64
+  %add11 = add i128 %add9, %shr10
+  ret i128 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+; Same pattern as @mul_carry on <4 x i32>, checking that vector types are
+; handled and fold to a <4 x i64> multiply.
+define <4 x i32> @mul_carry_v4i32(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: define <4 x i32> @mul_carry_v4i32(
+; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i32> [[Y]] to <4 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i64> [[TMP2]], splat (i64 32)
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw <4 x i64> [[TMP3]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[ADD11]]
+;
+entry:
+  %shr = lshr <4 x i32> %x, <i32 16, i32 16, i32 16, i32 16>
+  %and = and <4 x i32> %x, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %shr1 = lshr <4 x i32> %y, <i32 16, i32 16, i32 16, i32 16>
+  %and2 = and <4 x i32> %y, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %mul = mul nuw <4 x i32> %shr, %and2
+  %mul3 = mul nuw <4 x i32> %and, %shr1
+  %add = add <4 x i32> %mul, %mul3
+  %mul4 = mul nuw <4 x i32> %and, %and2
+  %shr5 = lshr <4 x i32> %mul4, <i32 16, i32 16, i32 16, i32 16>
+  %add6 = add <4 x i32> %add, %shr5
+  %cmp = icmp ult <4 x i32> %add6, %mul
+  %cond = select <4 x i1> %cmp, <4 x i32> <i32 65536, i32 65536, i32 65536, i32 65536>, <4 x i32> zeroinitializer
+  %mul8 = mul nuw <4 x i32> %shr, %shr1
+  %add9 = add nuw <4 x i32> %mul8, %cond
+  %shr10 = lshr <4 x i32> %add6, <i32 16, i32 16, i32 16, i32 16>
+  %add11 = add <4 x i32> %add9, %shr10
+  ret <4 x i32> %add11
+}
+
+; Check carry against xlyh, not xhyl
+; The compare uses %mul3 (xl*yh) instead of %mul (xh*yl), so %x and %y swap
+; roles and the resulting zexts appear in the order %y, %x.
+define i32 @mul_carry_xlyh(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_xlyh(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %shr, %and2
+  %mul3 = mul nuw i32 %and, %shr1
+  %add = add i32 %mul, %mul3
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %add, %shr5
+  %cmp = icmp ult i32 %add6, %mul3
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %add9 = add nuw i32 %mul8, %cond
+  %shr10 = lshr i32 %add6, 16
+  %add11 = add i32 %add9, %shr10
+  ret i32 %add11
+}
+
+; Commuted variant of @mul_carry. The cross-product muls and the low-sum adds
+; have their operands swapped, and the final sum adds %cond and %shr10 together
+; before adding xh*yh (%mul8).
+define i32 @mul_carry_comm(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %and2, %shr
+  %mul3 = mul nuw i32 %shr1, %and
+  %add = add i32 %mul3, %mul
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %shr5, %add
+  %cmp = icmp ult i32 %add6, %mul
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %shr10 = lshr i32 %add6, 16
+  %add9 = add nuw i32 %cond, %shr10
+  %add11 = add i32 %add9, %mul8
+  ret i32 %add11
+}
+
+
+; Negative tests
+
+
+define i32 @mul_carry_notxlo(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_notxlo(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X]], 32767
+; CHECK-NEXT:    [[SHR1:%.*]] = lshr i32 [[Y]], 16
+; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[Y]], 65535
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[SHR]], [[AND2]]
+; CHECK-NEXT:    [[MUL3:%.*]] = mul nuw nsw i32 [[AND]], [[SHR1]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL]], [[MUL3]]
+; CHECK-NEXT:    [[MUL4:%.*]] = mul nuw nsw i32 [[AND]], [[AND2]]
+; CHECK-NEXT:    [[SHR5:%.*]] = lshr i32 [[MUL4]], 16
+; CHECK-NEXT:    [[ADD6:%.*]] = add i32 [[ADD]], [[SHR5]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD6]], [[MUL]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 65536, i32 0
+; CHECK-NEXT:    [[MUL8:%.*]] = mul nuw i32 [[SHR]], [[SHR1]]
+; CHECK-NE...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/168396


More information about the llvm-commits mailing list