[llvm] [VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective (PR #81852)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 15 04:54:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>

Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free.

If the add/mul/and/or/xor reduction is cheap enough on the pre-truncated type, perform the reduction on that wider type and truncate only the scalar result, avoiding the vector truncation entirely.

Fixes https://github.com/llvm/llvm-project/issues/81469

---
Full diff: https://github.com/llvm/llvm-project/pull/81852.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+63) 
- (added) llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll (+91) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index f18711ba30b708..20fb9de7c75aa9 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -111,6 +111,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -1526,6 +1527,67 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+/// Determine if it's more efficient to fold:
+///   reduce(trunc(x)) -> trunc(reduce(x)).
+/// Legal because truncation commutes with add/mul/and/or/xor.
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+
+  unsigned ReductionOpc = 0;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add:
+    ReductionOpc = Instruction::Add;
+    break;
+  case Intrinsic::vector_reduce_mul:
+    ReductionOpc = Instruction::Mul;
+    break;
+  case Intrinsic::vector_reduce_and:
+    ReductionOpc = Instruction::And;
+    break;
+  case Intrinsic::vector_reduce_or:
+    ReductionOpc = Instruction::Or;
+    break;
+  case Intrinsic::vector_reduce_xor:
+    ReductionOpc = Instruction::Xor;
+    break;
+  default:
+    return false; // e.g. min/max reductions don't commute with trunc.
+  }
+  Value *ReductionSrc = I.getOperand(0);
+  // Require a single-use trunc: the cost model assumes it becomes dead.
+  Value *Src;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))))
+    return false;
+
+  auto *Trunc = cast<CastInst>(ReductionSrc);
+  auto *SrcTy = cast<VectorType>(Src->getType());
+  auto *ReductionTy = cast<VectorType>(ReductionSrc->getType());
+  Type *ResultTy = I.getType();
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(Instruction::Trunc, ReductionTy, SrcTy,
+                           TTI::CastContextHint::None, CostKind, Trunc) +
+      TTI.getArithmeticReductionCost(ReductionOpc, ReductionTy, std::nullopt,
+                                     CostKind);
+  InstructionCost NewCost =
+      TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
+                                     CostKind) +
+      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           ReductionTy->getScalarType(),
+                           TTI::CastContextHint::None, CostKind);
+  // Only fold on a strict improvement; equal costs would just churn the IR.
+  if (OldCost <= NewCost || !NewCost.isValid())
+    return false;
+
+  Value *NewReduction = Builder.CreateIntrinsic(
+      SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+  replaceValue(I, *NewTruncation);
+  return true;
+}
+
 /// This method looks for groups of shuffles acting on binops, of the form:
 ///  %x = shuffle ...
 ///  %y = shuffle ...
@@ -1917,6 +1979,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
+        MadeChange |= foldTruncFromReductions(I);
         break;
       case Instruction::ICmp:
       case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
new file mode 100644
index 00000000000000..54d8250cdec816
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64    | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v2 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v3 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s --check-prefixes=CHECK,AVX512
+
+;
+; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
+;
+
+; AVX512 has a cheap v8i64 -> v8i32 truncation: fold only on pre-AVX512 targets
+define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0)  {
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT:    ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT:    ret i32 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+
+; No legal vXi8 multiplication, so the vXi16 reduction is always cheaper: fold
+define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0)  {
+; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i16 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
+; OR reduction commutes with trunc: reduce the v8i32 source, truncate scalar
+define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0)  {
+; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <8 x i32> %a0 to <8 x i8>
+  %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
+; XOR reduction of a v16i64 -> v16i8 truncation: folded on all targets
+define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0)  {
+; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i64> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
+
+; Negative Test: vXi16 multiply is much cheaper than vXi64, so keep the trunc
+define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0)  {
+; CHECK-LABEL: @reduce_mul_trunc_v8i64_i16(
+; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i16>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> [[TR]])
+; CHECK-NEXT:    ret i16 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i16>
+  %red = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> %tr)
+  ret i16 %red
+}
+declare i16 @llvm.vector.reduce.mul.v8i16(<8 x i16>)
+
+; Negative Test: min/max reductions don't commute with trunc, so never fold.
+define i8 @reduce_smin_trunc_v16i16_i8(<16 x i16> %a0)  {
+; CHECK-LABEL: @reduce_smin_trunc_v16i16_i8(
+; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.smin.v16i8(<16 x i8>)
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/81852


More information about the llvm-commits mailing list