[llvm] [AggressiveInstCombine] Match long high-half multiply (PR #168396)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 17 08:21:08 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-transforms
Author: David Green (davemgreen)
<details>
<summary>Changes</summary>
This patch adds recognition of a high-half multiply computed in parts, folding
it into a single larger multiply.
Considering a multiply made up of high and low parts, we can split the
multiply into:
```
x * y == (xh*T + xl) * (yh*T + yl)
```
where `xh == x>>32`, `xl == x & 0xffffffff` and `T = 2^32`.
This expands to
```
xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
```
which I find helpful to draw as
```
[ xh*yh ]
[ xh*yl ]
[ xl*yh ]
[ xl*yl ]
```
We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
carries. The carry handling makes this difficult and there are multiple ways of
representing it. The ones we attempt to support here are:
```
Carry: xh*yh + carry + lowsum>>32
carry = lowsum < xh*yl ? 0x100000000 : 0
lowsum = xh*yl + xl*yh + (xl*yl>>32)
Ladder: xh*yh + c2>>32 + c3>>32
c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
Carry4: xh*yh + carry + crosssum>>32 + (xl*yl>>32 + crosssum&0xffffffff) >> 32
crosssum = xh*yl + xl*yh
carry = crosssum < xh*yl ? 0x100000000 : 0
Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
```
They all start by matching `xh*yh` plus two or three other operands. The leaves
of the tree are `xh*yh`, `xh*yl`, `xl*yh` and `xl*yl`.
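To make the variants concrete, here is a minimal C++ sketch (my illustration, not code from the patch; the function names are invented) of the "Carry" and "Ladder" forms computing the high 64 bits of a 64x64 multiply from 32-bit halves:
```
#include <cstdint>

// "Carry" variant: sum the cross products plus the high half of the low
// product, then detect whether the 64-bit sum wrapped and re-add the loss.
uint64_t umulh_carry(uint64_t x, uint64_t y) {
  uint64_t xh = x >> 32, xl = x & 0xffffffff;
  uint64_t yh = y >> 32, yl = y & 0xffffffff;
  uint64_t lowsum = xh * yl + (xl * yh + ((xl * yl) >> 32));
  // The sum wraps at most once; a wrap loses 2^64 from lowsum, which is
  // worth 2^32 in the high half.
  uint64_t carry = lowsum < xh * yl ? 0x100000000 : 0;
  return xh * yh + carry + (lowsum >> 32);
}

// "Ladder" variant: mask between the adds so no partial sum can wrap,
// avoiding the compare-and-select carry.
uint64_t umulh_ladder(uint64_t x, uint64_t y) {
  uint64_t xh = x >> 32, xl = x & 0xffffffff;
  uint64_t yh = y >> 32, yl = y & 0xffffffff;
  uint64_t c2 = xh * yl + ((xl * yl) >> 32); // cannot wrap
  uint64_t c3 = (c2 & 0xffffffff) + xl * yh; // cannot wrap
  return xh * yh + (c2 >> 32) + (c3 >> 32);
}
```
Both compute `(uint64_t)(((unsigned __int128)x * y) >> 64)`, which is essentially what the transform emits in their place: zero-extend to twice the width, multiply, shift right by the original width, truncate.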
Based on https://github.com/llvm/llvm-project/pull/156879 by @c-rhodes
---
Patch is 226.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168396.diff
5 Files Affected:
- (modified) llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+301)
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll (+755)
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry4.ll (+3019)
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder.ll (+818)
- (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder4.ll (+530)
``````````diff
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index b575d76e897d2..fb71f57eaa502 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+/// x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
+/// This expands to
+/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [ xh*yh ]
+/// [ xh*yl ]
+/// [ xl*yh ]
+/// [ xl*yl ]
+/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
+/// some carries. The carry makes this difficult and there are multiple ways of
+/// representing it. The ones we attempt to support here are:
+/// Carry: xh*yh + carry + lowsum>>32
+/// carry = lowsum < xh*yl ? 0x100000000 : 0
+/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
+/// Ladder: xh*yh + c2>>32 + c3>>32
+/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
+/// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl>>32 + crosssum&0xffffffff) >> 32
+/// crosssum = xh*yl + xl*yh
+/// carry = crosssum < xh*yl ? 0x100000000 : 0
+/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+static bool foldMulHigh(Instruction &I) {
+ Type *Ty = I.getType();
+ if (!Ty->isIntOrIntVectorTy())
+ return false;
+
+ unsigned BW = Ty->getScalarSizeInBits();
+ APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+ if (BW % 2 != 0)
+ return false;
+
+ auto CreateMulHigh = [&](Value *X, Value *Y) {
+ IRBuilder<> Builder(&I);
+ Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+ Value *XExt = Builder.CreateZExt(X, NTy);
+ Value *YExt = Builder.CreateZExt(Y, NTy);
+ Value *Mul = Builder.CreateMul(XExt, YExt);
+ Value *High = Builder.CreateLShr(Mul, BW);
+ Value *Res = Builder.CreateTrunc(High, Ty);
+ I.replaceAllUsesWith(Res);
+ LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+ << *Y << "\n");
+ return true;
+ };
+
+ // Common check routines for X_lo*Y_lo and X_hi*Y_lo
+ auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+ return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+ auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+ return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+
+ auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+ Instruction *B) {
+ // Looking for LowSum >> 32 and carry (select)
+ if (Carry->getOpcode() != Instruction::Select)
+ std::swap(Carry, B);
+
+ // Carry = LowSum < XhYl ? 0x100000000 : 0
+ CmpPredicate Pred;
+ Value *LowSum, *XhYl;
+ if (!match(Carry,
+ m_OneUse(m_Select(
+ m_OneUse(m_ICmp(Pred, m_Value(LowSum), m_Value(XhYl))),
+ m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+ Pred != ICmpInst::ICMP_ULT)
+ return false;
+
+ // XhYl can be Xh*Yl or Xl*Yh
+ if (!CheckHiLo(XhYl, X, Y)) {
+ if (CheckHiLo(XhYl, Y, X))
+ std::swap(X, Y);
+ else
+ return false;
+ }
+ if (XhYl->hasNUsesOrMore(3))
+ return false;
+
+ // B = LowSum >> 32
+ if (!match(B,
+ m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
+ LowSum->hasNUsesOrMore(3))
+ return false;
+
+ // LowSum = XhYl + XlYh + XlYl>>32
+ Value *XlYh, *XlYl;
+ auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
+ if (!match(LowSum,
+ m_c_Add(m_Specific(XhYl),
+ m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
+ !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
+ m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
+ !match(LowSum,
+ m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
+ m_OneUse(m_Value(XlYh)))))))
+ return false;
+
+ // Check XlYl and XlYh
+ if (!CheckLoLo(XlYl, X, Y))
+ return false;
+ if (!CheckHiLo(XlYh, Y, X))
+ return false;
+
+ return CreateMulHigh(X, Y);
+ };
+
+ auto foldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
+ Instruction *B) {
+ // xh*yh + c2>>32 + c3>>32
+ // c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
+ Value *XlYh, *XhYl, *C2, *C3;
+ // Strip off the two expected shifts.
+ if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
+ !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
+ return false;
+
+ // Match c3 = c2&0xffffffff + xl*yh
+ if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
+ m_Value(XhYl))))
+ std::swap(C2, C3);
+ if (!match(C3,
+ m_c_Add(m_OneUse(m_And(m_Specific(C2), m_SpecificInt(LowMask))),
+ m_Value(XhYl))) ||
+ !C3->hasOneUse() || C2->hasNUsesOrMore(3))
+ return false;
+
+ // Match c2 = xh*yl + (xl*yl >> 32)
+ Value *XlYl;
+ if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
+ m_Value(XlYh))))
+ return false;
+
+ // Match XhYl and XlYh - they can appear either way around.
+ if (!CheckHiLo(XlYh, Y, X))
+ std::swap(XlYh, XhYl);
+ if (!CheckHiLo(XlYh, Y, X))
+ return false;
+ if (!CheckHiLo(XhYl, X, Y))
+ return false;
+ if (!CheckLoLo(XlYl, X, Y))
+ return false;
+
+ return CreateMulHigh(X, Y);
+ };
+
+ auto foldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
+ Instruction *B, Instruction *C) {
+ /// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+ /// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+
+ // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
+ auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2));
+ if (!match(A, ShiftAdd))
+ std::swap(A, B);
+ if (!match(A, ShiftAdd))
+ std::swap(A, C);
+ Value *Low;
+ if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2))))
+ return false;
+
+ // Match B == XhYl>>32 and C == XlYh>>32
+ Value *XhYl, *XlYh;
+ if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) ||
+ !match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2))))
+ return false;
+ if (!CheckHiLo(XhYl, X, Y))
+ std::swap(XhYl, XlYh);
+ if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
+ return false;
+ if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
+ return false;
+
+ // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
+ Value *XlYl;
+ if (!match(
+ Low,
+ m_c_Add(
+ m_OneUse(m_c_Add(
+ m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+ m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
+ m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) &&
+ !match(
+ Low,
+ m_c_Add(
+ m_OneUse(m_c_Add(
+ m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+ m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+ m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
+ !match(
+ Low,
+ m_c_Add(
+ m_OneUse(m_c_Add(
+ m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
+ m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+ m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
+ return false;
+ if (!CheckLoLo(XlYl, X, Y))
+ return false;
+
+ return CreateMulHigh(X, Y);
+ };
+
+ auto foldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
+ Instruction *B, Instruction *C) {
+ // xh*yh + carry + crosssum>>32 + (xl*yl>>32 + crosssum&0xffffffff) >> 32
+ // crosssum = xh*yl+xl*yh
+ // carry = crosssum < xh*yl ? 0x100000000 : 0
+ if (Carry->getOpcode() != Instruction::Select)
+ std::swap(Carry, B);
+ if (Carry->getOpcode() != Instruction::Select)
+ std::swap(Carry, C);
+
+ // Carry = CrossSum < XhYl ? 0x100000000 : 0
+ CmpPredicate Pred;
+ Value *CrossSum, *XhYl;
+ if (!match(Carry,
+ m_OneUse(m_Select(
+ m_OneUse(m_ICmp(Pred, m_Value(CrossSum), m_Value(XhYl))),
+ m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+ Pred != ICmpInst::ICMP_ULT)
+ return false;
+
+ if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+ std::swap(B, C);
+ if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+ return false;
+
+ Value *XlYl, *LowAccum;
+ if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) ||
+ !match(LowAccum,
+ m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
+ m_OneUse(m_And(m_Specific(CrossSum),
+ m_SpecificInt(LowMask))))) ||
+ LowAccum->hasNUsesOrMore(3))
+ return false;
+ if (!CheckLoLo(XlYl, X, Y))
+ return false;
+
+ if (!CheckHiLo(XhYl, X, Y))
+ std::swap(X, Y);
+ if (!CheckHiLo(XhYl, X, Y))
+ return false;
+ if (!match(CrossSum,
+ m_c_Add(m_Specific(XhYl),
+ m_OneUse(m_c_Mul(
+ m_LShr(m_Specific(Y), m_SpecificInt(BW / 2)),
+ m_And(m_Specific(X), m_SpecificInt(LowMask)))))) ||
+ CrossSum->hasNUsesOrMore(4) || XhYl->hasNUsesOrMore(3))
+ return false;
+
+ return CreateMulHigh(X, Y);
+ };
+
+ // X and Y are the two inputs, A, B and C are other parts of the pattern
+ // (crosssum>>32, carry, etc).
+ Value *X, *Y;
+ Instruction *A, *B, *C;
+ auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)),
+ m_LShr(m_Value(Y), m_SpecificInt(BW / 2))));
+ if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
+ m_Instruction(B))))) ||
+ match(&I, m_c_Add(m_Instruction(A),
+ m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
+ A->hasOneUse() && B->hasOneUse())
+ if (foldMulHighCarry(X, Y, A, B) || foldMulHighLadder(X, Y, A, B))
+ return true;
+
+ if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
+ m_Instruction(A),
+ m_OneUse(m_Add(m_Instruction(B),
+ m_Instruction(C))))))) ||
+ match(&I, m_c_Add(m_Instruction(A),
+ m_OneUse(m_c_Add(
+ HiHi, m_OneUse(m_Add(m_Instruction(B),
+ m_Instruction(C))))))) ||
+ match(&I, m_c_Add(m_Instruction(A),
+ m_OneUse(m_c_Add(
+ m_Instruction(B),
+ m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
+ match(&I,
+ m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
+ m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
+ A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
+ return foldMulHighCarry4(X, Y, A, B, C) ||
+ foldMulHighLadder4(X, Y, A, B, C);
+
+ return false;
+}
+
/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1795,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
MadeChange |= foldPatternedLoads(I, DL);
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
+ MadeChange |= foldMulHigh(I);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
new file mode 100644
index 0000000000000..b78095cac0df9
--- /dev/null
+++ b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
@@ -0,0 +1,755 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=aggressive-instcombine,instcombine -S | FileCheck %s
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define i32 @mul_carry(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT: [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[ADD11]]
+;
+entry:
+ %shr = lshr i32 %x, 16
+ %and = and i32 %x, 65535
+ %shr1 = lshr i32 %y, 16
+ %and2 = and i32 %y, 65535
+ %mul = mul nuw i32 %shr, %and2
+ %mul3 = mul nuw i32 %and, %shr1
+ %add = add i32 %mul, %mul3
+ %mul4 = mul nuw i32 %and, %and2
+ %shr5 = lshr i32 %mul4, 16
+ %add6 = add i32 %add, %shr5
+ %cmp = icmp ult i32 %add6, %mul
+ %cond = select i1 %cmp, i32 65536, i32 0
+ %mul8 = mul nuw i32 %shr, %shr1
+ %add9 = add nuw i32 %mul8, %cond
+ %shr10 = lshr i32 %add6, 16
+ %add11 = add i32 %add9, %shr10
+ ret i32 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define i128 @mul_carry_i128(i128 %x, i128 %y) {
+; CHECK-LABEL: define i128 @mul_carry_i128(
+; CHECK-SAME: i128 [[X:%.*]], i128 [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext i128 [[X]] to i256
+; CHECK-NEXT: [[TMP1:%.*]] = zext i128 [[Y]] to i256
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw i256 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i256 [[TMP2]], 128
+; CHECK-NEXT: [[ADD11:%.*]] = trunc nuw i256 [[TMP3]] to i128
+; CHECK-NEXT: ret i128 [[ADD11]]
+;
+entry:
+ %shr = lshr i128 %x, 64
+ %and = and i128 %x, u0xffffffffffffffff
+ %shr1 = lshr i128 %y, 64
+ %and2 = and i128 %y, u0xffffffffffffffff
+ %mul = mul nuw i128 %shr, %and2
+ %mul3 = mul nuw i128 %and, %shr1
+ %add = add i128 %mul, %mul3
+ %mul4 = mul nuw i128 %and, %and2
+ %shr5 = lshr i128 %mul4, 64
+ %add6 = add i128 %add, %shr5
+ %cmp = icmp ult i128 %add6, %mul
+ %cond = select i1 %cmp, i128 u0x10000000000000000, i128 0
+ %mul8 = mul nuw i128 %shr, %shr1
+ %add9 = add nuw i128 %mul8, %cond
+ %shr10 = lshr i128 %add6, 64
+ %add11 = add i128 %add9, %shr10
+ ret i128 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define <4 x i32> @mul_carry_v4i32(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: define <4 x i32> @mul_carry_v4i32(
+; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i32> [[Y]] to <4 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i64> [[TMP2]], splat (i64 32)
+; CHECK-NEXT: [[ADD11:%.*]] = trunc nuw <4 x i64> [[TMP3]] to <4 x i32>
+; CHECK-NEXT: ret <4 x i32> [[ADD11]]
+;
+entry:
+ %shr = lshr <4 x i32> %x, <i32 16, i32 16, i32 16, i32 16>
+ %and = and <4 x i32> %x, <i32 65535, i32 65535, i32 65535, i32 65535>
+ %shr1 = lshr <4 x i32> %y, <i32 16, i32 16, i32 16, i32 16>
+ %and2 = and <4 x i32> %y, <i32 65535, i32 65535, i32 65535, i32 65535>
+ %mul = mul nuw <4 x i32> %shr, %and2
+ %mul3 = mul nuw <4 x i32> %and, %shr1
+ %add = add <4 x i32> %mul, %mul3
+ %mul4 = mul nuw <4 x i32> %and, %and2
+ %shr5 = lshr <4 x i32> %mul4, <i32 16, i32 16, i32 16, i32 16>
+ %add6 = add <4 x i32> %add, %shr5
+ %cmp = icmp ult <4 x i32> %add6, %mul
+ %cond = select <4 x i1> %cmp, <4 x i32> <i32 65536, i32 65536, i32 65536, i32 65536>, <4 x i32> zeroinitializer
+ %mul8 = mul nuw <4 x i32> %shr, %shr1
+ %add9 = add nuw <4 x i32> %mul8, %cond
+ %shr10 = lshr <4 x i32> %add6, <i32 16, i32 16, i32 16, i32 16>
+ %add11 = add <4 x i32> %add9, %shr10
+ ret <4 x i32> %add11
+}
+
+; Check carry against xlyh, not xhyl
+define i32 @mul_carry_xlyh(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_xlyh(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT: [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[ADD11]]
+;
+entry:
+ %shr = lshr i32 %x, 16
+ %and = and i32 %x, 65535
+ %shr1 = lshr i32 %y, 16
+ %and2 = and i32 %y, 65535
+ %mul = mul nuw i32 %shr, %and2
+ %mul3 = mul nuw i32 %and, %shr1
+ %add = add i32 %mul, %mul3
+ %mul4 = mul nuw i32 %and, %and2
+ %shr5 = lshr i32 %mul4, 16
+ %add6 = add i32 %add, %shr5
+ %cmp = icmp ult i32 %add6, %mul3
+ %cond = select i1 %cmp, i32 65536, i32 0
+ %mul8 = mul nuw i32 %shr, %shr1
+ %add9 = add nuw i32 %mul8, %cond
+ %shr10 = lshr i32 %add6, 16
+ %add11 = add i32 %add9, %shr10
+ ret i32 %add11
+}
+
+define i32 @mul_carry_comm(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT: [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[ADD11]]
+;
+entry:
+ %shr = lshr i32 %x, 16
+ %and = and i32 %x, 65535
+ %shr1 = lshr i32 %y, 16
+ %and2 = and i32 %y, 65535
+ %mul = mul nuw i32 %and2, %shr
+ %mul3 = mul nuw i32 %shr1, %and
+ %add = add i32 %mul3, %mul
+ %mul4 = mul nuw i32 %and, %and2
+ %shr5 = lshr i32 %mul4, 16
+ %add6 = add i32 %shr5, %add
+ %cmp = icmp ult i32 %add6, %mul
+ %cond = select i1 %cmp, i32 65536, i32 0
+ %mul8 = mul nuw i32 %shr, %shr1
+ %shr10 = lshr i32 %add6, 16
+ %add9 = add nuw i32 %cond, %shr10
+ %add11 = add i32 %add9, %mul8
+ ret i32 %add11
+}
+
+
+; Negative tests
+
+
+define i32 @mul_carry_notxlo(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_notxlo(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], 32767
+; CHECK-NEXT: [[SHR1:%.*]] = lshr i32 [[Y]], 16
+; CHECK-NEXT: [[AND2:%.*]] = and i32 [[Y]], 65535
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[SHR]], [[AND2]]
+; CHECK-NEXT: [[MUL3:%.*]] = mul nuw nsw i32 [[AND]], [[SHR1]]
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[MUL]], [[MUL3]]
+; CHECK-NEXT: [[MUL4:%.*]] = mul nuw nsw i32 [[AND]], [[AND2]]
+; CHECK-NEXT: [[SHR5:%.*]] = lshr i32 [[MUL4]], 16
+; CHECK-NEXT: [[ADD6:%.*]] = add i32 [[ADD]], [[SHR5]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ADD6]], [[MUL]]
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 65536, i32 0
+; CHECK-NEXT: [[MUL8:%.*]] = mul nuw i32 [[SHR]], [[SHR1]]
+; CHECK-NE...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/168396
More information about the llvm-commits mailing list