[llvm] 8efee51 - [InstCombine] limit pair-of-insertelement folds to avoid miscompile
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 15 05:28:21 PST 2022
Author: Sanjay Patel
Date: 2022-12-15T08:27:43-05:00
New Revision: 8efee510be3c2dd02db0b070055aedf095a3acce
URL: https://github.com/llvm/llvm-project/commit/8efee510be3c2dd02db0b070055aedf095a3acce
DIFF: https://github.com/llvm/llvm-project/commit/8efee510be3c2dd02db0b070055aedf095a3acce.diff
LOG: [InstCombine] limit pair-of-insertelement folds to avoid miscompile
This transform was added with 4446f71ce392. However, as noted in
the post-commit feedback, the transform is not safe with an
arbitrary base vector: bitcasting that vector to fewer, wider
elements means a poison narrow element makes the entire wide element
poison, leaking poison into an adjacent element that was not poison
before.
I made the least invasive code change (the fold now bails out unless
the base vector matches m_Undef()) in case we do figure out a way to
make this safe.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
llvm/test/Transforms/InstCombine/insertelt-trunc.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 099ad3dfafe2c..84d3134c6694d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1525,6 +1525,9 @@ static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt,
Value *IndexOp = InsElt.getOperand(2);
// inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1
+ // Note: It is not safe to do this transform with an arbitrary base vector
+ // because the bitcast of that vector to fewer/larger elements could
+ // allow poison to spill into an element that was not poison before.
// TODO: The insertion order could be reversed.
// TODO: Detect smaller fractions of the scalar.
// TODO: One-use checks are conservative.
@@ -1536,7 +1539,8 @@ static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt,
m_ConstantInt(Index0)))) ||
!match(ScalarOp, m_OneUse(m_Trunc(m_LShr(m_Specific(X),
m_ConstantInt(ShAmt))))) ||
- !match(IndexOp, m_ConstantInt(Index1)))
+ !match(IndexOp, m_ConstantInt(Index1)) ||
+ !match(BaseVec, m_Undef()))
return nullptr;
Type *SrcTy = X->getType();
diff --git a/llvm/test/Transforms/InstCombine/insertelt-trunc.ll b/llvm/test/Transforms/InstCombine/insertelt-trunc.ll
index bd829d5526d1a..cd735b85d242d 100644
--- a/llvm/test/Transforms/InstCombine/insertelt-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/insertelt-trunc.ll
@@ -158,20 +158,16 @@ define <2 x i16> @insert_01_v2i16(i32 %x, <2 x i16> %v) {
ret <2 x i16> %ins1
}
+; negative test - can't do this safely without knowing something about the base vector
+
define <8 x i16> @insert_10_v8i16(i32 %x, <8 x i16> %v) {
-; BE-LABEL: @insert_10_v8i16(
-; BE-NEXT: [[TMP1:%.*]] = bitcast <8 x i16> [[V:%.*]] to <4 x i32>
-; BE-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X:%.*]], i64 0
-; BE-NEXT: [[INS1:%.*]] = bitcast <4 x i32> [[TMP2]] to <8 x i16>
-; BE-NEXT: ret <8 x i16> [[INS1]]
-;
-; LE-LABEL: @insert_10_v8i16(
-; LE-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
-; LE-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
-; LE-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
-; LE-NEXT: [[INS0:%.*]] = insertelement <8 x i16> [[V:%.*]], i16 [[LO16]], i64 1
-; LE-NEXT: [[INS1:%.*]] = insertelement <8 x i16> [[INS0]], i16 [[HI16]], i64 0
-; LE-NEXT: ret <8 x i16> [[INS1]]
+; ALL-LABEL: @insert_10_v8i16(
+; ALL-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
+; ALL-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
+; ALL-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
+; ALL-NEXT: [[INS0:%.*]] = insertelement <8 x i16> [[V:%.*]], i16 [[LO16]], i64 1
+; ALL-NEXT: [[INS1:%.*]] = insertelement <8 x i16> [[INS0]], i16 [[HI16]], i64 0
+; ALL-NEXT: ret <8 x i16> [[INS1]]
;
%hi32 = lshr i32 %x, 16
%hi16 = trunc i32 %hi32 to i16
@@ -219,20 +215,16 @@ define <4 x i16> @insert_21_v4i16(i32 %x, <4 x i16> %v) {
ret <4 x i16> %ins1
}
+; negative test - can't do this safely without knowing something about the base vector
+
define <4 x i32> @insert_23_v4i32(i64 %x, <4 x i32> %v) {
-; BE-LABEL: @insert_23_v4i32(
-; BE-NEXT: [[HI64:%.*]] = lshr i64 [[X:%.*]], 32
-; BE-NEXT: [[HI32:%.*]] = trunc i64 [[HI64]] to i32
-; BE-NEXT: [[LO32:%.*]] = trunc i64 [[X]] to i32
-; BE-NEXT: [[INS0:%.*]] = insertelement <4 x i32> [[V:%.*]], i32 [[LO32]], i64 2
-; BE-NEXT: [[INS1:%.*]] = insertelement <4 x i32> [[INS0]], i32 [[HI32]], i64 3
-; BE-NEXT: ret <4 x i32> [[INS1]]
-;
-; LE-LABEL: @insert_23_v4i32(
-; LE-NEXT: [[TMP1:%.*]] = bitcast <4 x i32> [[V:%.*]] to <2 x i64>
-; LE-NEXT: [[TMP2:%.*]] = insertelement <2 x i64> [[TMP1]], i64 [[X:%.*]], i64 1
-; LE-NEXT: [[INS1:%.*]] = bitcast <2 x i64> [[TMP2]] to <4 x i32>
-; LE-NEXT: ret <4 x i32> [[INS1]]
+; ALL-LABEL: @insert_23_v4i32(
+; ALL-NEXT: [[HI64:%.*]] = lshr i64 [[X:%.*]], 32
+; ALL-NEXT: [[HI32:%.*]] = trunc i64 [[HI64]] to i32
+; ALL-NEXT: [[LO32:%.*]] = trunc i64 [[X]] to i32
+; ALL-NEXT: [[INS0:%.*]] = insertelement <4 x i32> [[V:%.*]], i32 [[LO32]], i64 2
+; ALL-NEXT: [[INS1:%.*]] = insertelement <4 x i32> [[INS0]], i32 [[HI32]], i64 3
+; ALL-NEXT: ret <4 x i32> [[INS1]]
;
%hi64 = lshr i64 %x, 32
%hi32 = trunc i64 %hi64 to i32
@@ -242,20 +234,16 @@ define <4 x i32> @insert_23_v4i32(i64 %x, <4 x i32> %v) {
ret <4 x i32> %ins1
}
+; negative test - can't do this safely without knowing something about the base vector
+
define <4 x i16> @insert_32_v4i16(i32 %x, <4 x i16> %v) {
-; BE-LABEL: @insert_32_v4i16(
-; BE-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[V:%.*]] to <2 x i32>
-; BE-NEXT: [[TMP2:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[X:%.*]], i64 1
-; BE-NEXT: [[INS1:%.*]] = bitcast <2 x i32> [[TMP2]] to <4 x i16>
-; BE-NEXT: ret <4 x i16> [[INS1]]
-;
-; LE-LABEL: @insert_32_v4i16(
-; LE-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
-; LE-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
-; LE-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
-; LE-NEXT: [[INS0:%.*]] = insertelement <4 x i16> [[V:%.*]], i16 [[LO16]], i64 3
-; LE-NEXT: [[INS1:%.*]] = insertelement <4 x i16> [[INS0]], i16 [[HI16]], i64 2
-; LE-NEXT: ret <4 x i16> [[INS1]]
+; ALL-LABEL: @insert_32_v4i16(
+; ALL-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
+; ALL-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
+; ALL-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
+; ALL-NEXT: [[INS0:%.*]] = insertelement <4 x i16> [[V:%.*]], i16 [[LO16]], i64 3
+; ALL-NEXT: [[INS1:%.*]] = insertelement <4 x i16> [[INS0]], i16 [[HI16]], i64 2
+; ALL-NEXT: ret <4 x i16> [[INS1]]
;
%hi32 = lshr i32 %x, 16
%hi16 = trunc i32 %hi32 to i16
@@ -324,25 +312,18 @@ define <8 x i16> @insert_67_v4i16_uses1(i32 %x, <8 x i16> %v) {
ret <8 x i16> %ins1
}
-; extra use is ok
+; negative test - can't do this safely without knowing something about the base vector
+; extra use would be ok
define <8 x i16> @insert_76_v4i16_uses2(i32 %x, <8 x i16> %v) {
-; BE-LABEL: @insert_76_v4i16_uses2(
-; BE-NEXT: [[LO16:%.*]] = trunc i32 [[X:%.*]] to i16
-; BE-NEXT: call void @use(i16 [[LO16]])
-; BE-NEXT: [[TMP1:%.*]] = bitcast <8 x i16> [[V:%.*]] to <4 x i32>
-; BE-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X]], i64 3
-; BE-NEXT: [[INS1:%.*]] = bitcast <4 x i32> [[TMP2]] to <8 x i16>
-; BE-NEXT: ret <8 x i16> [[INS1]]
-;
-; LE-LABEL: @insert_76_v4i16_uses2(
-; LE-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
-; LE-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
-; LE-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
-; LE-NEXT: call void @use(i16 [[LO16]])
-; LE-NEXT: [[INS0:%.*]] = insertelement <8 x i16> [[V:%.*]], i16 [[LO16]], i64 7
-; LE-NEXT: [[INS1:%.*]] = insertelement <8 x i16> [[INS0]], i16 [[HI16]], i64 6
-; LE-NEXT: ret <8 x i16> [[INS1]]
+; ALL-LABEL: @insert_76_v4i16_uses2(
+; ALL-NEXT: [[HI32:%.*]] = lshr i32 [[X:%.*]], 16
+; ALL-NEXT: [[HI16:%.*]] = trunc i32 [[HI32]] to i16
+; ALL-NEXT: [[LO16:%.*]] = trunc i32 [[X]] to i16
+; ALL-NEXT: call void @use(i16 [[LO16]])
+; ALL-NEXT: [[INS0:%.*]] = insertelement <8 x i16> [[V:%.*]], i16 [[LO16]], i64 7
+; ALL-NEXT: [[INS1:%.*]] = insertelement <8 x i16> [[INS0]], i16 [[HI16]], i64 6
+; ALL-NEXT: ret <8 x i16> [[INS1]]
;
%hi32 = lshr i32 %x, 16
%hi16 = trunc i32 %hi32 to i16
More information about the llvm-commits mailing list