[llvm] 1f3e35b - [AggressiveInstCombine] Add shift left instruction to `TruncInstCombine` DAG
Anton Afanasyev via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 17 03:17:13 PDT 2021
Author: Anton Afanasyev
Date: 2021-08-17T12:44:37+03:00
New Revision: 1f3e35b6d165715ec7bf7ba80d5b982719c7752a
URL: https://github.com/llvm/llvm-project/commit/1f3e35b6d165715ec7bf7ba80d5b982719c7752a
DIFF: https://github.com/llvm/llvm-project/commit/1f3e35b6d165715ec7bf7ba80d5b982719c7752a.diff
LOG: [AggressiveInstCombine] Add shift left instruction to `TruncInstCombine` DAG
Add the `shl` instruction to the DAG post-dominated by `trunc`, allowing
TruncInstCombine to reduce the bitwidth of expressions containing left shifts.
The only condition to check is that the target bitwidth is strictly greater
than the maximal shift amount: https://alive2.llvm.org/ce/z/AwArqu
Part of https://reviews.llvm.org/D107766
Differential Revision: https://reviews.llvm.org/D108091
Added:
Modified:
llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 16b82219e8ca3..b614cfd7b9b09 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -29,10 +29,12 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/Support/KnownBits.h"
using namespace llvm;
@@ -61,6 +63,7 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
Ops.push_back(I->getOperand(0));
Ops.push_back(I->getOperand(1));
break;
@@ -127,6 +130,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
case Instruction::Select: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
@@ -137,7 +141,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
// TODO: Can handle more cases here:
// 1. shufflevector, extractelement, insertelement
// 2. udiv, urem
- // 3. shl, lshr, ashr
+ // 3. lshr, ashr
// 4. phi node(and loop handling)
// ...
return false;
@@ -270,6 +274,23 @@ Type *TruncInstCombine::getBestTruncatedType() {
unsigned OrigBitWidth =
CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
+ // Initialize MinBitWidth for `shl` instructions with the minimum number
+ // that is greater than shift amount (i.e. shift amount + 1).
+ // Also normalize MinBitWidth not to be greater than source bitwidth.
+ for (auto &Itr : InstInfoMap) {
+ Instruction *I = Itr.first;
+ if (I->getOpcode() == Instruction::Shl) {
+ KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
+ const unsigned SrcBitWidth = KnownRHS.getBitWidth();
+ unsigned MinBitWidth =
+ KnownRHS.getMaxValue().uadd_sat(APInt(SrcBitWidth, 1)).getZExtValue();
+ MinBitWidth = std::min(MinBitWidth, SrcBitWidth);
+ if (MinBitWidth >= OrigBitWidth)
+ return nullptr;
+ Itr.second.MinBitWidth = MinBitWidth;
+ }
+ }
+
// Calculate the minimum bit-width allowed for shrinking the currently
// visited truncate's operand.
unsigned MinBitWidth = getMinBitWidth();
@@ -356,7 +377,8 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
case Instruction::Mul:
case Instruction::And:
case Instruction::Or:
- case Instruction::Xor: {
+ case Instruction::Xor:
+ case Instruction::Shl: {
Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
diff --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
index 67d78293564e7..e7f491aa51054 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
@@ -3,10 +3,9 @@
define i16 @shl_1(i8 %x) {
; CHECK-LABEL: @shl_1(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[ZEXT]], 1
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[ZEXT]], 1
+; CHECK-NEXT: ret i16 [[SHL]]
;
%zext = zext i8 %x to i32
%shl = shl i32 %zext, 1
@@ -16,10 +15,9 @@ define i16 @shl_1(i8 %x) {
define i16 @shl_15(i8 %x) {
; CHECK-LABEL: @shl_15(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[ZEXT]], 15
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[ZEXT]], 15
+; CHECK-NEXT: ret i16 [[SHL]]
;
%zext = zext i8 %x to i32
%shl = shl i32 %zext, 15
@@ -61,12 +59,11 @@ define i16 @shl_var_shift_amount(i8 %x, i8 %y) {
define i16 @shl_var_bounded_shift_amount(i8 %x, i8 %y) {
; CHECK-LABEL: @shl_var_bounded_shift_amount(
-; CHECK-NEXT: [[ZEXT_X:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[ZEXT_Y:%.*]] = zext i8 [[Y:%.*]] to i32
-; CHECK-NEXT: [[AND:%.*]] = and i32 [[ZEXT_Y]], 15
-; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[ZEXT_X]], [[AND]]
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ZEXT_X:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[ZEXT_Y:%.*]] = zext i8 [[Y:%.*]] to i16
+; CHECK-NEXT: [[AND:%.*]] = and i16 [[ZEXT_Y]], 15
+; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[ZEXT_X]], [[AND]]
+; CHECK-NEXT: ret i16 [[SHL]]
;
%zext.x = zext i8 %x to i32
%zext.y = zext i8 %y to i32
@@ -78,10 +75,9 @@ define i16 @shl_var_bounded_shift_amount(i8 %x, i8 %y) {
define <2 x i16> @shl_vector(<2 x i8> %x) {
; CHECK-LABEL: @shl_vector(
-; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT: [[S:%.*]] = shl <2 x i32> [[Z]], <i32 4, i32 10>
-; CHECK-NEXT: [[T:%.*]] = trunc <2 x i32> [[S]] to <2 x i16>
-; CHECK-NEXT: ret <2 x i16> [[T]]
+; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
+; CHECK-NEXT: [[S:%.*]] = shl <2 x i16> [[Z]], <i16 4, i16 10>
+; CHECK-NEXT: ret <2 x i16> [[S]]
;
%z = zext <2 x i8> %x to <2 x i32>
%s = shl <2 x i32> %z, <i32 4, i32 10>
@@ -121,10 +117,9 @@ define <2 x i16> @shl_vector_large_shift_amount(<2 x i8> %x) {
define i16 @shl_nuw(i8 %x) {
; CHECK-LABEL: @shl_nuw(
-; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[S:%.*]] = shl nuw i32 [[Z]], 15
-; CHECK-NEXT: [[T:%.*]] = trunc i32 [[S]] to i16
-; CHECK-NEXT: ret i16 [[T]]
+; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[S:%.*]] = shl i16 [[Z]], 15
+; CHECK-NEXT: ret i16 [[S]]
;
%z = zext i8 %x to i32
%s = shl nuw i32 %z, 15
@@ -134,10 +129,9 @@ define i16 @shl_nuw(i8 %x) {
define i16 @shl_nsw(i8 %x) {
; CHECK-LABEL: @shl_nsw(
-; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[S:%.*]] = shl nsw i32 [[Z]], 15
-; CHECK-NEXT: [[T:%.*]] = trunc i32 [[S]] to i16
-; CHECK-NEXT: ret i16 [[T]]
+; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[S:%.*]] = shl i16 [[Z]], 15
+; CHECK-NEXT: ret i16 [[S]]
;
%z = zext i8 %x to i32
%s = shl nsw i32 %z, 15
More information about the llvm-commits mailing list