[llvm] e5a32d7 - [InstCombine] move extend after insertelement if both operands are extended

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 15 11:38:55 PDT 2021


Author: Sanjay Patel
Date: 2021-09-15T14:38:03-04:00
New Revision: e5a32d720ef2d8989442a533e1dd2d7e667155c1

URL: https://github.com/llvm/llvm-project/commit/e5a32d720ef2d8989442a533e1dd2d7e667155c1
DIFF: https://github.com/llvm/llvm-project/commit/e5a32d720ef2d8989442a533e1dd2d7e667155c1.diff

LOG: [InstCombine] move extend after insertelement if both operands are extended

I was wondering how instcombine does on the examples in D109236,
and we're missing a basic transform:

inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index)

https://alive2.llvm.org/ce/z/z2aBu9

Note that there are several possible extensions of this fold
(see TODO comments).

Differential Revision: https://reviews.llvm.org/D109537

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
    llvm/test/Transforms/InstCombine/insert-ext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c56026e0d5d8..47ab278faaa0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1407,6 +1407,41 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) {
   return nullptr;
 }
 
+/// If both the base vector and the inserted element are the same kind of extend
+/// from the same type, return an extend of the narrow insert, else nullptr.
+/// TODO: This can be extended to include other cast opcodes, but particularly
+///       if we create a wider insertelement, make sure codegen is not harmed.
+static Instruction *narrowInsElt(InsertElementInst &InsElt,
+                                 InstCombiner::BuilderTy &Builder) {
+  // We are creating a vector extend. If the original vector extend has another
+  // use, that would mean we end up with 2 vector extends, so avoid that.
+  // TODO: We could relax the use check to "if at least one op has one use"
+  //       (assuming that the source types match - see next TODO comment).
+  Value *Vec = InsElt.getOperand(0);
+  if (!Vec->hasOneUse())
+    return nullptr;
+
+  Value *Scalar = InsElt.getOperand(1);
+  Value *X, *Y; // Narrow sources of the vector extend and the scalar extend.
+  CastInst::CastOps CastOpcode;
+  if (match(Vec, m_FPExt(m_Value(X))) && match(Scalar, m_FPExt(m_Value(Y))))
+    CastOpcode = Instruction::FPExt;
+  else if (match(Vec, m_SExt(m_Value(X))) && match(Scalar, m_SExt(m_Value(Y))))
+    CastOpcode = Instruction::SExt;
+  else if (match(Vec, m_ZExt(m_Value(X))) && match(Scalar, m_ZExt(m_Value(Y))))
+    CastOpcode = Instruction::ZExt;
+  else
+    return nullptr; // Operands are not both fpext, both sext, or both zext.
+
+  // TODO: We can allow mismatched types by creating an intermediate cast.
+  if (X->getType()->getScalarType() != Y->getType())
+    return nullptr;
+
+  // inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index)
+  Value *NewInsElt = Builder.CreateInsertElement(X, Y, InsElt.getOperand(2));
+  return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType());
+}
+
 Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
   Value *VecOp    = IE.getOperand(0);
   Value *ScalarOp = IE.getOperand(1);
@@ -1526,6 +1561,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
   if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE))
     return IdentityShuf;
 
+  if (Instruction *Ext = narrowInsElt(IE, Builder))
+    return Ext;
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/insert-ext.ll b/llvm/test/Transforms/InstCombine/insert-ext.ll
index 2367171055d9..3eb6ab1e560b 100644
--- a/llvm/test/Transforms/InstCombine/insert-ext.ll
+++ b/llvm/test/Transforms/InstCombine/insert-ext.ll
@@ -6,9 +6,8 @@ declare void @usevec(<2 x i32>)
 
 define <2 x double> @fpext_fpext(<2 x half> %x, half %y, i32 %index) {
 ; CHECK-LABEL: @fpext_fpext(
-; CHECK-NEXT:    [[V:%.*]] = fpext <2 x half> [[X:%.*]] to <2 x double>
-; CHECK-NEXT:    [[S:%.*]] = fpext half [[Y:%.*]] to double
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x double> [[V]], double [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x half> [[X:%.*]], half [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = fpext <2 x half> [[TMP1]] to <2 x double>
 ; CHECK-NEXT:    ret <2 x double> [[I]]
 ;
   %v = fpext <2 x half> %x to <2 x double>
@@ -19,9 +18,8 @@ define <2 x double> @fpext_fpext(<2 x half> %x, half %y, i32 %index) {
 
 define <2 x i32> @sext_sext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext(
-; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[S:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i32> [[V]], i32 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[I]]
 ;
   %v = sext <2 x i8> %x to <2 x i32>
@@ -32,9 +30,8 @@ define <2 x i32> @sext_sext(<2 x i8> %x, i8 %y, i32 %index) {
 
 define <2 x i12> @zext_zext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext(
-; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i12>
-; CHECK-NEXT:    [[S:%.*]] = zext i8 [[Y:%.*]] to i12
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i12> [[V]], i12 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = zext <2 x i8> [[TMP1]] to <2 x i12>
 ; CHECK-NEXT:    ret <2 x i12> [[I]]
 ;
   %v = zext <2 x i8> %x to <2 x i12>
@@ -43,6 +40,8 @@ define <2 x i12> @zext_zext(<2 x i8> %x, i8 %y, i32 %index) {
   ret <2 x i12> %i
 }
 
+; negative test - need same source type
+
 define <2 x double> @fpext_fpext_types(<2 x half> %x, float %y, i32 %index) {
 ; CHECK-LABEL: @fpext_fpext_types(
 ; CHECK-NEXT:    [[V:%.*]] = fpext <2 x half> [[X:%.*]] to <2 x double>
@@ -56,6 +55,8 @@ define <2 x double> @fpext_fpext_types(<2 x half> %x, float %y, i32 %index) {
   ret <2 x double> %i
 }
 
+; negative test - need same source type
+
 define <2 x i32> @sext_sext_types(<2 x i16> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext_types(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
@@ -69,6 +70,8 @@ define <2 x i32> @sext_sext_types(<2 x i16> %x, i8 %y, i32 %index) {
   ret <2 x i32> %i
 }
 
+; negative test - need same extend opcode
+
 define <2 x i12> @sext_zext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_zext(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i12>
@@ -82,6 +85,8 @@ define <2 x i12> @sext_zext(<2 x i8> %x, i8 %y, i32 %index) {
   ret <2 x i12> %i
 }
 
+; negative test - don't trade scalar extend for vector extend
+
 define <2 x i32> @sext_sext_use1(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext_use1(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
@@ -99,10 +104,10 @@ define <2 x i32> @sext_sext_use1(<2 x i8> %x, i8 %y, i32 %index) {
 
 define <2 x i32> @zext_zext_use2(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext_use2(
-; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[S:%.*]] = zext i8 [[Y:%.*]] to i32
 ; CHECK-NEXT:    call void @use(i32 [[S]])
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i32> [[V]], i32 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = zext <2 x i8> [[TMP1]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[I]]
 ;
   %v = zext <2 x i8> %x to <2 x i32>
@@ -112,6 +117,8 @@ define <2 x i32> @zext_zext_use2(<2 x i8> %x, i8 %y, i32 %index) {
   ret <2 x i32> %i
 }
 
+; negative test - don't create an extra extend
+
 define <2 x i32> @zext_zext_use3(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext_use3(
 ; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>


        


More information about the llvm-commits mailing list