[llvm] VectorWiden pass to widen aleady vectorized instrctions (PR #67029)

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 8 12:20:03 PDT 2023


https://github.com/dtemirbulatov updated https://github.com/llvm/llvm-project/pull/67029

>From 1191edc18daf0ad3827f5893c73f6ac23991bbe6 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <dinar.temirbulatov at arm.com>
Date: Thu, 21 Sep 2023 13:22:46 +0000
Subject: [PATCH] This pass allows us to widen already vectorized instructions
 to wider vector types. We encountered an issue with the current
 auto-vectorisations passes that would not allow us to implement the required
 functionality without a new pass easily. For example, SME2 ADD instruction
 with a first operand and the resulting register are in a multivector form
 with scalable vector types, while the third operand is just a regular
 scalable vector type:

add     { z4.s, z5.s }, { z4.s, z5.s }, z3.s

With the loop-vectorizer pass, choosing a correct VF to deal with one of
the operands and the result to be a wider vector type could be difficult.
With the new pass, we want to consider a group of operations, not a single one
like SLP does, to make more profitable transformations, including, for example,
LOADs and  STORES, arithmetical operations, etc. For example, we could combine
those independent ADD operations in a single basic block into one on ARM SVE:
typedef int v4si __attribute__ ((vector_size (16)));

void add(v4si *ptr, v4si *ptr1) {
  v4si a = *ptr;
  ptr++;
  v4si b = *ptr;
  ptr++;
  v4si c = *ptr;
  ptr++;
  v4si d = *ptr;
  *ptr1 = a+b;
  ptr1++;
  *ptr1 = c+d;
}

On ARM SVE hardware, we could produce:
        ptrue   p0.s, vl8
        mov     x8, #8             // =0x8
        ld1w    { z0.s }, p0/z, [x0]
        ld1w    { z1.s }, p0/z, [x0, x8, lsl #2]
        add     z0.s, z1.s, z0.s
        st1w    { z0.s }, p0, [x1]
        ret

Currently, we have this output https://godbolt.org/z/z5n78TWsc:
        ldp     q0, q1, [x0]
        ldp     q2, q3, [x0, #32]
        add     v0.4s, v1.4s, v0.4s
        add     v1.4s, v3.4s, v2.4s
        stp     q0, q1, [x1]
        ret

I noticed similar opportunities with SLP vectorizer not choosing wider VF due to
its implementation, for example, with reductions only able to handle four or
fewer elements width types. Currently, the pass supports only ADD and FP_ROUND
operations to widen and only in scalable vector types.
---
 .../llvm/Analysis/TargetTransformInfo.h       |  11 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   5 +
 .../llvm/Transforms/Vectorize/VectorWiden.h   |  25 ++
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 +
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../AArch64/AArch64TargetTransformInfo.cpp    |   9 +
 .../AArch64/AArch64TargetTransformInfo.h      |  21 ++
 llvm/lib/Transforms/Vectorize/CMakeLists.txt  |   1 +
 llvm/lib/Transforms/Vectorize/VectorWiden.cpp | 342 ++++++++++++++++++
 llvm/test/Transforms/VectorWiden/add.ll       |  37 ++
 .../Transforms/VectorWiden/fptrunc-bad-dep.ll |  45 +++
 llvm/test/Transforms/VectorWiden/fptrunc.ll   |  40 ++
 .../llvm/lib/Transforms/Vectorize/BUILD.gn    |   1 +
 14 files changed, 544 insertions(+)
 create mode 100644 llvm/include/llvm/Transforms/Vectorize/VectorWiden.h
 create mode 100644 llvm/lib/Transforms/Vectorize/VectorWiden.cpp
 create mode 100644 llvm/test/Transforms/VectorWiden/add.ll
 create mode 100644 llvm/test/Transforms/VectorWiden/fptrunc-bad-dep.ll
 create mode 100644 llvm/test/Transforms/VectorWiden/fptrunc.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..bbce3189ce9d7a4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1672,6 +1672,10 @@ class TargetTransformInfo {
   /// \return The maximum number of function arguments the target supports.
   unsigned getMaxNumArgs() const;
 
+  /// \returns Whether vector operations are a good candidate for vector widen.
+  bool considerToWiden(LLVMContext &Context,
+                           ArrayRef<Instruction *> IL) const;
+
   /// @}
 
 private:
@@ -2041,6 +2045,8 @@ class TargetTransformInfo::Concept {
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
   virtual unsigned getMaxNumArgs() const = 0;
+  virtual bool considerToWiden(LLVMContext &Context,
+                                   ArrayRef<Instruction *> IL) const = 0;
 };
 
 template <typename T>
@@ -2757,6 +2763,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxNumArgs() const override {
     return Impl.getMaxNumArgs();
   }
+
+  bool considerToWiden(LLVMContext &Context,
+                           ArrayRef<Instruction *> IL) const override {
+    return Impl.considerToWiden(Context, IL);
+  }
 };
 
 template <typename T>
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..6caa3fa6f2eb971 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -895,6 +895,11 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxNumArgs() const { return UINT_MAX; }
 
+  bool considerToWiden(LLVMContext &Context,
+                           ArrayRef<Instruction *> IL) const {
+    return false;
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/Transforms/Vectorize/VectorWiden.h b/llvm/include/llvm/Transforms/Vectorize/VectorWiden.h
new file mode 100644
index 000000000000000..6988785a92ce09c
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/VectorWiden.h
@@ -0,0 +1,25 @@
+//===--- VectorWiden.h - Combining Vector Operations to wider types ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VECTORWIDENING_H
+#define LLVM_TRANSFORMS_VECTORIZE_VECTORWIDENING_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class VectorWidenPass : public PassInfoMixin<VectorWidenPass> {
+public:
+  VectorWidenPass() {}
+
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VECTORWIDENING_H
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..93302427ee8223a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1248,6 +1248,11 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::considerToWiden(
+    LLVMContext &Context, ArrayRef<Instruction *> IL) const {
+  return TTIImpl->considerToWiden(Context, IL);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 985ff88139323c6..45c2529f92a5881 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -263,6 +263,7 @@
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 #include "llvm/Transforms/Vectorize/SLPVectorizer.h"
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/Transforms/Vectorize/VectorWiden.h"
 #include <optional>
 
 using namespace llvm;
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index df9f14920f29161..2eef2f0a22d95d7 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -428,6 +428,7 @@ FUNCTION_PASS("tailcallelim", TailCallElimPass())
 FUNCTION_PASS("typepromotion", TypePromotionPass(TM))
 FUNCTION_PASS("unify-loop-exits", UnifyLoopExitsPass())
 FUNCTION_PASS("vector-combine", VectorCombinePass())
+FUNCTION_PASS("vector-widen", VectorWidenPass())
 FUNCTION_PASS("verify", VerifierPass())
 FUNCTION_PASS("verify<domtree>", DominatorTreeVerifierPass())
 FUNCTION_PASS("verify<loops>", LoopVerifierPass())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cded28054f59259..1da184b48ec5370 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2426,6 +2426,15 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
             CostKind, I));
   }
 
+  static const TypeConversionCostTblEntry SME2Tbl[] = {
+    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 1 }
+  };
+
+  if (ST->hasSME2())
+    if (const auto *Entry = ConvertCostTableLookup(
+        SME2Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
+      return AdjustCost(Entry->Cost);
+
   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                  DstTy.getSimpleVT(),
                                                  SrcTy.getSimpleVT()))
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..140dd5014756ace 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -412,6 +412,27 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
     return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
   }
+
+  bool considerToWiden(LLVMContext &Context,
+                           ArrayRef<Instruction *> IL) const {
+    unsigned Opcode = IL[0]->getOpcode();
+    Type *Ty = IL[0]->getType();
+    if (!ST->hasSME2())
+      return false;
+    if (llvm::any_of(IL, [Opcode, Ty](Instruction *I) {
+          return (Opcode != I->getOpcode() || Ty != I->getType());
+        }))
+      return false;
+    if (Opcode == Instruction::FPTrunc &&
+        Ty == ScalableVectorType::get(Type::getHalfTy(Context), 4))
+      return true;
+    if (Opcode == Instruction::Add &&
+        Ty == ScalableVectorType::get(Type::getInt32Ty(Context), 4) &&
+        (IL[0]->getOperand(1) == IL[1]->getOperand(1) ||
+         IL[0]->getOperand(0) == IL[1]->getOperand(0)))
+      return true;
+    return false;
+  }
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 998dfd956575d3c..a1537bb1ffa632e 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -5,6 +5,7 @@ add_llvm_component_library(LLVMVectorize
   SLPVectorizer.cpp
   Vectorize.cpp
   VectorCombine.cpp
+  VectorWiden.cpp
   VPlan.cpp
   VPlanHCFGBuilder.cpp
   VPlanRecipes.cpp
diff --git a/llvm/lib/Transforms/Vectorize/VectorWiden.cpp b/llvm/lib/Transforms/Vectorize/VectorWiden.cpp
new file mode 100644
index 000000000000000..5e20e57938c7dbe
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VectorWiden.cpp
@@ -0,0 +1,342 @@
+///==--- VectorWiden.cpp - Combining Vector Operations to wider types ----==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tries to widen vector operations to a wider type, it finds
+// independent from each other operations with a certain vector type as SLP does
+// with scalars by Bottom Up. It detects consecutive stores that can be put
+// together into a wider vector-stores. Next, it attempts to construct
+// vectorizable tree using the use-def chains.
+//
+//==------------------------------------------------------------------------==//
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/CodeMoverUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Vectorize/VectorWiden.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "vector-widen"
+
+// Due to independant operations to widening that we consider with possibility
+// to merge those operations into one and also to widening store if we find
+// later store instructions. We have to consider the distance between those
+// independent operations or we might introduce bad register pressure, etc.
+
+static cl::opt<unsigned>
+    MaxInstDistance("vw-max-instr-distance", cl::init(30), cl::Hidden,
+                    cl::desc("Maximum distance between instructions to"
+                             "consider to widen"));
+
+namespace {
+class VectorWiden {
+public:
+  using InstrList = SmallVector<Instruction *, 2>;
+  using ValueList = SmallVector<Value *, 2>;
+  VectorWiden(Function &F, const TargetTransformInfo &TTI, DominatorTree &DT)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}
+
+  bool run();
+
+private:
+  Function &F;
+  IRBuilder<> Builder;
+  const TargetTransformInfo &TTI;
+  DominatorTree &DT;
+  TargetLibraryInfo *TLI;
+
+  DenseSet<Instruction *> DeletedInstructions;
+
+  /// Checks if the instruction is marked for deletion.
+  bool isDeleted(Instruction *I) const { return DeletedInstructions.count(I); }
+
+  /// Removes an instruction from its block and eventually deletes it.
+  void eraseInstruction(Instruction *I) { DeletedInstructions.insert(I); }
+
+  bool processBB(BasicBlock &BB, LLVMContext &Context);
+
+  bool canWidenNode(ArrayRef<Instruction *> IL, LLVMContext &Context);
+
+  bool widenNode(ArrayRef<Instruction *> IL, LLVMContext &Context);
+
+  void widenCastInst(ArrayRef<Instruction *> IL);
+
+  void widenBinaryOperator(ArrayRef<Instruction *> IL, bool Reorder);
+
+  InstructionCost getOpCost(unsigned Opcode, Type *To, Type *From,
+                            Instruction *I);
+};
+} // namespace
+
+void VectorWiden::widenCastInst(ArrayRef<Instruction *> IL) {
+  Instruction *I = IL[0];
+  Instruction *I1 = IL[1];
+  VectorType *RetOrigType = cast<VectorType>(I->getType());
+  VectorType *OrigType = cast<VectorType>(I->getOperand(0)->getType());
+  VectorType *RetType = VectorType::getDoubleElementsVectorType(RetOrigType);
+  VectorType *OpType = VectorType::getDoubleElementsVectorType(OrigType);
+  Value *WideVec = UndefValue::get(OpType);
+  Builder.SetInsertPoint(I);
+  Function *InsertIntr = llvm::Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::vector_insert, {OpType, OrigType});
+  Value *Insert1 = Builder.CreateCall(
+      InsertIntr, {WideVec, I->getOperand(0), Builder.getInt64(0)});
+  Value *Insert2 = Builder.CreateCall(
+      InsertIntr, {Insert1, I1->getOperand(0), Builder.getInt64(4)});
+  Value *ResCast = Builder.CreateCast(Instruction::CastOps(I->getOpcode()),
+                                      Insert2, RetType);
+  Function *ExtractIntr = llvm::Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::vector_extract, {RetOrigType, RetType});
+  if (!I->users().empty()) {
+    Value *Res =
+        Builder.CreateCall(ExtractIntr, {ResCast, Builder.getInt64(4)});
+    I->replaceAllUsesWith(Res);
+  }
+  if (!I1->users().empty()) {
+    Value *Res2 =
+        Builder.CreateCall(ExtractIntr, {ResCast, Builder.getInt64(0)});
+    I1->replaceAllUsesWith(Res2);
+  }
+}
+
+void VectorWiden::widenBinaryOperator(ArrayRef<Instruction *> IL,
+                                      bool Reorder) {
+  Instruction *I = IL[0];
+  Instruction *I1 = IL[1];
+
+  Value *XHi = I->getOperand(0);
+  Value *XLo = I1->getOperand(0);
+  Value *YHi = I->getOperand(1);
+  Value *YLo = I1->getOperand(1);
+  if (Reorder) {
+    std::swap(XHi, YHi);
+    std::swap(XLo, YLo);
+  }
+
+  VectorType *RetOrigType = cast<VectorType>(I->getType());
+  VectorType *OrigType = cast<VectorType>(I->getOperand(0)->getType());
+  VectorType *RetType = VectorType::getDoubleElementsVectorType(RetOrigType);
+  VectorType *OpType = VectorType::getDoubleElementsVectorType(OrigType);
+  Value *WideVec = UndefValue::get(OpType);
+  Builder.SetInsertPoint(I);
+  Function *InsertIntr = llvm::Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::vector_insert, {OpType, OrigType});
+  Value *X1 =
+      Builder.CreateCall(InsertIntr, {WideVec, XHi, Builder.getInt64(0)});
+  Value *X2 = Builder.CreateCall(InsertIntr, {X1, XLo, Builder.getInt64(4)});
+  Value *Y1 =
+      Builder.CreateCall(InsertIntr, {WideVec, YHi, Builder.getInt64(0)});
+  Value *Y2 = Builder.CreateCall(InsertIntr, {Y1, YLo, Builder.getInt64(4)});
+  Value *ResBinOp =
+      Builder.CreateBinOp((Instruction::BinaryOps)I->getOpcode(), X2, Y2);
+  Function *ExtractIntr = llvm::Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::vector_extract, {RetOrigType, RetType});
+  if (!I->users().empty()) {
+    Value *Res =
+        Builder.CreateCall(ExtractIntr, {ResBinOp, Builder.getInt64(0)});
+    I->replaceAllUsesWith(Res);
+  }
+  if (!I1->users().empty()) {
+    Value *Res2 =
+        Builder.CreateCall(ExtractIntr, {ResBinOp, Builder.getInt64(4)});
+    I1->replaceAllUsesWith(Res2);
+  }
+}
+
+bool VectorWiden::canWidenNode(ArrayRef<Instruction *> IL,
+                               LLVMContext &Context) {
+  if (!TTI.considerToWiden(Context, IL))
+    return false;
+
+  for (int X = 0, E = IL.size(); X < E; X++) {
+    for (int Y = 0, E = IL.size(); Y < E; Y++) {
+      if (X == Y)
+        continue;
+      if ((IL[X] == IL[Y]) || (IL[X]->getOpcode() != IL[Y]->getOpcode()) ||
+          // Ignore if any live in a diffrent Basic Block
+          (IL[X]->getParent() != IL[Y]->getParent()) ||
+          // Ignore if disatance between two are too apart.
+          (abs(std::distance(IL[Y]->getIterator(), IL[X]->getIterator())) >
+           MaxInstDistance) ||
+          (IL[X]->getOperand(0) == IL[Y] ||
+           (IL[X]->getNumOperands() > 1 && IL[X]->getOperand(1) == IL[Y])))
+        return false;
+    }
+    if (isDeleted(IL[X]) || !IL[X]->hasOneUse())
+      return false;
+    if (IL[0]->getParent() == IL[X]->user_back()->getParent() &&
+        DT.dominates(IL[X]->user_back(), IL[0]))
+      return false;
+  }
+  return true;
+}
+
+bool VectorWiden::widenNode(ArrayRef<Instruction *> IL, LLVMContext &Context) {
+  LLVM_DEBUG(dbgs() << "VW: widenNode: " << *IL[0] << " " << *IL[1] << "\n");
+  assert(IL.size() == 2 && "Incorrect instructions list to widen.");
+  if (!canWidenNode(IL, Context))
+    return false;
+  if (dyn_cast<CastInst>(IL[0])) {
+    VectorType *RetOrigType = cast<VectorType>(IL[0]->getType());
+    VectorType *OrigType = cast<VectorType>(IL[0]->getOperand(0)->getType());
+    InstructionCost Cost =
+        getOpCost(Instruction::FPTrunc, RetOrigType, OrigType, IL[0]);
+    VectorType *RetType = VectorType::getDoubleElementsVectorType(RetOrigType);
+    VectorType *OpType = VectorType::getDoubleElementsVectorType(OrigType);
+    InstructionCost CostNew =
+        getOpCost(Instruction::FPTrunc, RetType, OpType, IL[0]);
+    if (2 * Cost < CostNew)
+      return false;
+    LLVM_DEBUG(dbgs() << "VW: Decided to widen CastInst, safe to merge : "
+                      << *IL[0] << " with  " << *IL[1] << "\n");
+    widenCastInst(IL);
+    return true;
+  }
+  if (dyn_cast<BinaryOperator>(IL[0])) {
+    VectorType *OrigType = cast<VectorType>(IL[0]->getOperand(0)->getType());
+    VectorType *OpType = VectorType::getDoubleElementsVectorType(OrigType);
+    InstructionCost Cost =
+        getOpCost(Instruction::Add, OrigType, OrigType, IL[0]);
+    InstructionCost CostNew =
+        getOpCost(Instruction::Add, OpType, OpType, IL[0]);
+    if (2 * Cost < CostNew)
+      return false;
+    LLVM_DEBUG(dbgs() << "VW: Decided to widen BinaryOp, safe to merge : "
+                      << *IL[0] << " with  " << *IL[1] << "\n");
+    widenBinaryOperator(IL, IL[0]->getOperand(1) != IL[1]->getOperand(1));
+    return true;
+  }
+  return false;
+}
+
+InstructionCost VectorWiden::getOpCost(unsigned Opcode, Type *To, Type *From,
+                                       Instruction *I) {
+  InstructionCost Cost = 0;
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  switch (Opcode) {
+  case Instruction::FPTrunc: {
+    Cost = TTI.getCastInstrCost(Opcode, To, From, TTI::getCastContextHint(I),
+                                CostKind, I);
+    break;
+  }
+  case Instruction::Add: {
+    unsigned OpIdx = isa<UnaryOperator>(I) ? 0 : 1;
+    TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(I->getOperand(0));
+    TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(I->getOperand(OpIdx));
+    SmallVector<const Value *> Operands(I->operand_values());
+    Cost = TTI.getArithmeticInstrCost(I->getOpcode(), To, CostKind, Op1Info,
+                                      Op2Info, Operands, I);
+    break;
+  }
+  default:
+    llvm_unreachable("Unknown instruction");
+  }
+  return Cost;
+}
+
+bool VectorWiden::processBB(BasicBlock &BB, LLVMContext &Context) {
+  DenseMap<unsigned, std::pair<InstrList, unsigned>> Operations;
+  unsigned Counter = 0;
+  for (BasicBlock::reverse_iterator IP(BB.rbegin()); IP != BB.rend();
+       *IP++, ++Counter) {
+    Instruction *I = &*IP;
+    unsigned OpFound = 0;
+
+    if (I->isDebugOrPseudoInst() || isDeleted(I))
+      continue;
+
+    if ((dyn_cast<BinaryOperator>(I) && I->getNumOperands() == 2) ||
+        (dyn_cast<CastInst>(I) &&
+         I->getOpcode() != Instruction::AddrSpaceCast &&
+         I->getNumOperands() == 1)) {
+      if (Operations.find(I->getOpcode()) != Operations.end()) {
+        auto *OpRec = &Operations[I->getOpcode()];
+        // If instructions are too apart then remove old instrction
+        // and reset position to this instruction.
+        if (Counter - Operations[I->getOpcode()].second > MaxInstDistance) {
+          OpRec->second = Counter;
+          OpRec->first.clear();
+          OpRec->first.push_back(I);
+        } else {
+          OpRec->first.push_back(I);
+          OpFound = I->getOpcode();
+        }
+      } else {
+        Operations[I->getOpcode()] = {{I}, Counter};
+      }
+    }
+
+    if (OpFound && Operations.find(OpFound) != Operations.end()) {
+      auto *OpRec = &Operations[OpFound];
+      for (Instruction *Op : OpRec->first)
+        LLVM_DEBUG(dbgs() << "VW Op to check : " << *Op << "\n");
+      if (!widenNode(OpRec->first, Context)) {
+        LLVM_DEBUG(dbgs() << "VW Unable to consturct the tree.\n");
+        OpRec->first.erase(OpRec->first.begin());
+        OpRec->second = Counter;
+      } else {
+        for (Instruction *Instr : OpRec->first)
+          eraseInstruction(Instr);
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool VectorWiden::run() {
+  bool Changed = false;
+  LLVMContext &Context = F.getContext();
+
+  LLVM_DEBUG(dbgs() << "VW Function:" << F.getName() << "\n");
+  for (BasicBlock &BB : F) {
+    LLVM_DEBUG(dbgs() << "VW BB:" << BB.getName() << "\n");
+
+    while (processBB(BB, Context))
+      Changed = true;
+  }
+
+  if (Changed)
+    for (auto *I : DeletedInstructions)
+      RecursivelyDeleteTriviallyDeadInstructions(I);
+
+  return Changed;
+}
+
+PreservedAnalyses VectorWidenPass::run(Function &F,
+                                       FunctionAnalysisManager &FAM) {
+  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
+  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+
+  VectorWiden VecWiden(F, TTI, DT);
+
+  if (!VecWiden.run())
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
diff --git a/llvm/test/Transforms/VectorWiden/add.ll b/llvm/test/Transforms/VectorWiden/add.ll
new file mode 100644
index 000000000000000..4ef96a2c60c77ec
--- /dev/null
+++ b/llvm/test/Transforms/VectorWiden/add.ll
@@ -0,0 +1,37 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt -passes=vector-widen -mtriple aarch64-linux-gnu -mattr=+sme2 -S %s 2>&1 | FileCheck %s
+
+define void @add(ptr %ptr, ptr %ptr1) {
+; CHECK-LABEL: define void @add
+; CHECK-SAME: (ptr [[PTR:%.*]], ptr [[PTR1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load <vscale x 4 x i32>, ptr [[PTR]], align 16
+; CHECK-NEXT:    [[INCDEC_PTR:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[PTR]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load <vscale x 4 x i32>, ptr [[INCDEC_PTR]], align 16
+; CHECK-NEXT:    [[INCDEC_PTR1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[PTR]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = load <vscale x 4 x i32>, ptr [[INCDEC_PTR1]], align 16
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> undef, <vscale x 4 x i32> [[TMP1]], i64 0)
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP3]], <vscale x 4 x i32> [[TMP0]], i64 4)
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> undef, <vscale x 4 x i32> [[TMP2]], i64 0)
+; CHECK-NEXT:    [[TMP6:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP5]], <vscale x 4 x i32> [[TMP2]], i64 4)
+; CHECK-NEXT:    [[TMP7:%.*]] = add <vscale x 8 x i32> [[TMP4]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP7]], i64 0)
+; CHECK-NEXT:    [[TMP9:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP7]], i64 4)
+; CHECK-NEXT:    store <vscale x 4 x i32> [[TMP9]], ptr [[PTR1]], align 16
+; CHECK-NEXT:    [[INCDEC_PTR3:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[PTR1]], i64 1
+; CHECK-NEXT:    store <vscale x 4 x i32> [[TMP8]], ptr [[INCDEC_PTR3]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = load <vscale x 4 x i32>, ptr %ptr, align 16
+  %incdec.ptr = getelementptr inbounds <vscale x 4 x i32>, ptr %ptr, i64 1
+  %1 = load <vscale x 4 x i32>, ptr %incdec.ptr, align 16
+  %incdec.ptr1 = getelementptr inbounds <vscale x 4 x i32>, ptr %ptr, i64 2
+  %2 = load <vscale x 4 x i32>, ptr %incdec.ptr1, align 16
+  %add = add <vscale x 4 x i32> %0, %2
+  %add4 = add <vscale x 4 x i32> %1, %2
+  store <vscale x 4 x i32> %add, ptr %ptr1, align 16
+  %incdec.ptr3 = getelementptr inbounds <vscale x 4 x i32>, ptr %ptr1, i64 1
+  store <vscale x 4 x i32> %add4, ptr %incdec.ptr3, align 16
+  ret void
+}
diff --git a/llvm/test/Transforms/VectorWiden/fptrunc-bad-dep.ll b/llvm/test/Transforms/VectorWiden/fptrunc-bad-dep.ll
new file mode 100644
index 000000000000000..9f00d5ca113f4f0
--- /dev/null
+++ b/llvm/test/Transforms/VectorWiden/fptrunc-bad-dep.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=vector-widen -mtriple aarch64-linux-gnu -mattr=+sme2 -S 2>&1 | FileCheck %s
+
+define void @fptrunc(ptr %ptr, ptr %ptr1) {
+; CHECK-LABEL: @fptrunc(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = shl nuw nsw i64 [[TMP3]], 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[PTR:%.*]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds float, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT:    [[WIDE_LOAD9:%.*]] = load <vscale x 4 x float>, ptr [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = fptrunc <vscale x 4 x float> [[WIDE_LOAD]] to <vscale x 4 x half>
+; CHECK-NEXT:    [[EXTR:%.*]] = call <vscale x 1 x half> @llvm.vector.extract.nxv1f16.nxv4f16(<vscale x 4 x half> [[TMP6]], i64 0)
+; CHECK-NEXT:    [[EXTEND:%.*]] = fpext <vscale x 1 x half> [[EXTR]] to <vscale x 1 x float>
+; CHECK-NEXT:    [[INS:%.*]] = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.nxv1f32(<vscale x 4 x float> [[WIDE_LOAD9]], <vscale x 1 x float> [[EXTEND]], i64 0)
+; CHECK-NEXT:    [[TMP7:%.*]] = fptrunc <vscale x 4 x float> [[INS]] to <vscale x 4 x half>
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds half, ptr [[PTR1:%.*]], i64 0
+; CHECK-NEXT:    store <vscale x 4 x half> [[TMP6]], ptr [[TMP8]], align 2
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds half, ptr [[TMP8]], i64 [[TMP4]]
+; CHECK-NEXT:    store <vscale x 4 x half> [[TMP7]], ptr [[TMP9]], align 2
+; CHECK-NEXT:    ret void
+;
+  %1 = tail call i64 @llvm.vscale.i64()
+  %2 = shl nuw nsw i64 %1, 2
+  %3 = tail call i64 @llvm.vscale.i64()
+  %4 = shl nuw nsw i64 %3, 2
+  %wide.load = load <vscale x 4 x float>, ptr %ptr, align 4
+  %5 = getelementptr inbounds float, ptr %ptr, i64 %2
+  %wide.load9 = load <vscale x 4 x float>, ptr %5, align 4
+  %6 = fptrunc <vscale x 4 x float> %wide.load to <vscale x 4 x half>
+  %extr = call <vscale x 1 x half> @llvm.vector.extract.nxv1f16.nxv4f16(<vscale x 4 x half> %6, i64 0)
+  %extend = fpext <vscale x 1 x half> %extr to <vscale x 1 x float>
+  %ins = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.nxv1f32(<vscale x 4 x float> %wide.load9, <vscale x 1 x float> %extend, i64 0)
+  %7 = fptrunc <vscale x 4 x float> %ins to <vscale x 4 x half>
+  %8 = getelementptr inbounds half, ptr %ptr1, i64 0
+  store <vscale x 4 x half> %6, ptr %8, align 2
+  %9 = getelementptr inbounds half, ptr %8, i64 %4
+  store <vscale x 4 x half> %7, ptr %9, align 2
+  ret void
+}
+
+declare i64 @llvm.vscale.i64()
+declare <vscale x 1 x half> @llvm.vector.extract.nxv1f16.nxv4f16(<vscale x 4 x half>, i64 immarg)
+declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.nxv1f32(<vscale x 4 x float>, <vscale x 1 x float>, i64 immarg)
diff --git a/llvm/test/Transforms/VectorWiden/fptrunc.ll b/llvm/test/Transforms/VectorWiden/fptrunc.ll
new file mode 100644
index 000000000000000..2d34d25e6d463c0
--- /dev/null
+++ b/llvm/test/Transforms/VectorWiden/fptrunc.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=vector-widen -mtriple aarch64-linux-gnu -mattr=+sme2 -S 2>&1 | FileCheck %s
+
+define void @fptrunc(ptr %ptr, ptr %ptr1) {
+; CHECK-LABEL: @fptrunc(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = shl nuw nsw i64 [[TMP3]], 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[PTR:%.*]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds float, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT:    [[WIDE_LOAD9:%.*]] = load <vscale x 4 x float>, ptr [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = call <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> undef, <vscale x 4 x float> [[WIDE_LOAD9]], i64 0)
+; CHECK-NEXT:    [[TMP7:%.*]] = call <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> [[TMP6]], <vscale x 4 x float> [[WIDE_LOAD]], i64 4)
+; CHECK-NEXT:    [[TMP8:%.*]] = fptrunc <vscale x 8 x float> [[TMP7]] to <vscale x 8 x half>
+; CHECK-NEXT:    [[TMP9:%.*]] = call <vscale x 4 x half> @llvm.vector.extract.nxv4f16.nxv8f16(<vscale x 8 x half> [[TMP8]], i64 4)
+; CHECK-NEXT:    [[TMP10:%.*]] = call <vscale x 4 x half> @llvm.vector.extract.nxv4f16.nxv8f16(<vscale x 8 x half> [[TMP8]], i64 0)
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[PTR1:%.*]], i64 0
+; CHECK-NEXT:    store <vscale x 4 x half> [[TMP10]], ptr [[TMP11]], align 2
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds half, ptr [[TMP11]], i64 [[TMP4]]
+; CHECK-NEXT:    store <vscale x 4 x half> [[TMP9]], ptr [[TMP12]], align 2
+; CHECK-NEXT:    ret void
+;
+  %1 = tail call i64 @llvm.vscale.i64()
+  %2 = shl nuw nsw i64 %1, 2
+  %3 = tail call i64 @llvm.vscale.i64()
+  %4 = shl nuw nsw i64 %3, 2
+  %wide.load = load <vscale x 4 x float>, ptr %ptr, align 4
+  %5 = getelementptr inbounds float, ptr %ptr, i64 %2
+  %wide.load9 = load <vscale x 4 x float>, ptr %5, align 4
+  %6 = fptrunc <vscale x 4 x float> %wide.load to <vscale x 4 x half>
+  %7 = fptrunc <vscale x 4 x float> %wide.load9 to <vscale x 4 x half>
+  %8 = getelementptr inbounds half, ptr %ptr1, i64 0
+  store <vscale x 4 x half> %6, ptr %8, align 2
+  %9 = getelementptr inbounds half, ptr %8, i64 %4
+  store <vscale x 4 x half> %7, ptr %9, align 2
+  ret void
+}
+
+declare i64 @llvm.vscale.i64()
diff --git a/llvm/utils/gn/secondary/llvm/lib/Transforms/Vectorize/BUILD.gn b/llvm/utils/gn/secondary/llvm/lib/Transforms/Vectorize/BUILD.gn
index ca67426e08699ba..f5ef7bbd7106a55 100644
--- a/llvm/utils/gn/secondary/llvm/lib/Transforms/Vectorize/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/lib/Transforms/Vectorize/BUILD.gn
@@ -19,5 +19,6 @@ static_library("Vectorize") {
     "VPlanVerifier.cpp",
     "VectorCombine.cpp",
     "Vectorize.cpp",
+    "VectorWiden.cpp",
   ]
 }



More information about the llvm-commits mailing list