[llvm] [VPlan] Preserve trunc nuw/nsw in VPRecipeWithIRFlags (PR #144700)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 07:02:42 PDT 2025


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/144700

>From 81e4298d3668ac047a354889aa0d5298315f0228 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 18 Jun 2025 14:13:18 +0100
Subject: [PATCH 1/3] [VPlan][IR] Preserve trunc nuw/nsw in VPRecipeWithIRFlags

This preserves the nuw/nsw flags on widened truncs by splitting up OverflowingBinaryOperator into OverflowingOperator and OverflowingBinaryOperator.

We could probably go through the other users of OverflowingBinaryOperator and see if they could be generalized to OverflowingOperator later.

The motivation for this is to be able to fold away some redundant truncs feeding into uitofp instructions (or potentially narrow the inductions feeding them).
---
 llvm/include/llvm/IR/Operator.h               | 34 ++++++++++++++-----
 llvm/lib/Transforms/Vectorize/VPlan.h         | 16 ++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  6 ++--
 .../AArch64/conditional-branches-cost.ll      |  4 +--
 .../ARM/mve-reg-pressure-vmla.ll              |  4 +--
 .../PhaseOrdering/ARM/arm_mult_q15.ll         |  2 +-
 6 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index 8344eaec807b3..13272ed2727ec 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -73,9 +73,9 @@ class Operator : public User {
 };
 
 /// Utility class for integer operators which may exhibit overflow - Add, Sub,
-/// Mul, and Shl. It does not include SDiv, despite that operator having the
-/// potential for overflow.
-class OverflowingBinaryOperator : public Operator {
+/// Mul, Shl and Trunc. It does not include SDiv, despite that operator having
+/// the potential for overflow.
+class OverflowingOperator : public Operator {
 public:
   enum {
     AnyWrap        = 0,
@@ -97,9 +97,6 @@ class OverflowingBinaryOperator : public Operator {
   }
 
 public:
-  /// Transparently provide more efficient getOperand methods.
-  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value);
-
   /// Test whether this operation is known to never
   /// undergo unsigned overflow, aka the nuw property.
   bool hasNoUnsignedWrap() const {
@@ -131,11 +128,32 @@ class OverflowingBinaryOperator : public Operator {
     return I->getOpcode() == Instruction::Add ||
            I->getOpcode() == Instruction::Sub ||
            I->getOpcode() == Instruction::Mul ||
-           I->getOpcode() == Instruction::Shl;
+           I->getOpcode() == Instruction::Shl ||
+           I->getOpcode() == Instruction::Trunc;
   }
   static bool classof(const ConstantExpr *CE) {
     return CE->getOpcode() == Instruction::Add ||
-           CE->getOpcode() == Instruction::Sub;
+           CE->getOpcode() == Instruction::Sub ||
+           CE->getOpcode() == Instruction::Trunc;
+  }
+  static bool classof(const Value *V) {
+    return (isa<Instruction>(V) && classof(cast<Instruction>(V))) ||
+           (isa<ConstantExpr>(V) && classof(cast<ConstantExpr>(V)));
+  }
+};
+
+/// The subset of OverflowingOperators that are also BinaryOperators.
+class OverflowingBinaryOperator : public OverflowingOperator {
+public:
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value);
+
+  static bool classof(const Instruction *I) {
+    return OverflowingOperator::classof(I) && isa<BinaryOperator>(I);
+  }
+  static bool classof(const ConstantExpr *CE) {
+    return OverflowingOperator::classof(CE) &&
+           Instruction::isBinaryOp(CE->getOpcode());
   }
   static bool classof(const Value *V) {
     return (isa<Instruction>(V) && classof(cast<Instruction>(V))) ||
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index f3306ad7cb8ec..8b9a03cb941e9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -595,7 +595,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
-    OverflowingBinOp,
+    OverflowingOp,
     DisjointOp,
     PossiblyExactOp,
     GEPOp,
@@ -661,8 +661,8 @@ class VPIRFlags {
     } else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
       OpType = OperationType::DisjointOp;
       DisjointFlags.IsDisjoint = Op->isDisjoint();
-    } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
-      OpType = OperationType::OverflowingBinOp;
+    } else if (auto *Op = dyn_cast<OverflowingOperator>(&I)) {
+      OpType = OperationType::OverflowingOp;
       WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
     } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
       OpType = OperationType::PossiblyExactOp;
@@ -686,7 +686,7 @@ class VPIRFlags {
       : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
   VPIRFlags(WrapFlagsTy WrapFlags)
-      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
+      : OpType(OperationType::OverflowingOp), WrapFlags(WrapFlags) {}
 
   VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
@@ -710,7 +710,7 @@ class VPIRFlags {
     // NOTE: This needs to be kept in-sync with
     // Instruction::dropPoisonGeneratingFlags.
     switch (OpType) {
-    case OperationType::OverflowingBinOp:
+    case OperationType::OverflowingOp:
       WrapFlags.HasNUW = false;
       WrapFlags.HasNSW = false;
       break;
@@ -739,7 +739,7 @@ class VPIRFlags {
   /// Apply the IR flags to \p I.
   void applyFlags(Instruction &I) const {
     switch (OpType) {
-    case OperationType::OverflowingBinOp:
+    case OperationType::OverflowingOp:
       I.setHasNoUnsignedWrap(WrapFlags.HasNUW);
       I.setHasNoSignedWrap(WrapFlags.HasNSW);
       break;
@@ -799,13 +799,13 @@ class VPIRFlags {
   }
 
   bool hasNoUnsignedWrap() const {
-    assert(OpType == OperationType::OverflowingBinOp &&
+    assert(OpType == OperationType::OverflowingOp &&
            "recipe doesn't have a NUW flag");
     return WrapFlags.HasNUW;
   }
 
   bool hasNoSignedWrap() const {
-    assert(OpType == OperationType::OverflowingBinOp &&
+    assert(OpType == OperationType::OverflowingOp &&
            "recipe doesn't have a NSW flag");
     return WrapFlags.HasNSW;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..96bb6bb0d37b5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1627,9 +1627,9 @@ VPIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
 #if !defined(NDEBUG)
 bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
   switch (OpType) {
-  case OperationType::OverflowingBinOp:
+  case OperationType::OverflowingOp:
     return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
-           Opcode == Instruction::Mul ||
+           Opcode == Instruction::Mul || Opcode == Instruction::Trunc ||
            Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
   case OperationType::DisjointOp:
     return Opcode == Instruction::Or;
@@ -1672,7 +1672,7 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
     if (ExactFlags.IsExact)
       O << " exact";
     break;
-  case OperationType::OverflowingBinOp:
+  case OperationType::OverflowingOp:
     if (WrapFlags.HasNUW)
       O << " nuw";
     if (WrapFlags.HasNSW)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll b/llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll
index 976f95ff4f0ba..8bcfce8cc52c9 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll
@@ -1484,7 +1484,7 @@ define void @redundant_branch_and_tail_folding(ptr %dst, i1 %c) {
 ; DEFAULT-NEXT:    [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; DEFAULT-NEXT:    [[STEP_ADD:%.*]] = add <4 x i64> [[VEC_IND]], splat (i64 4)
 ; DEFAULT-NEXT:    [[TMP0:%.*]] = add nuw nsw <4 x i64> [[STEP_ADD]], splat (i64 1)
-; DEFAULT-NEXT:    [[TMP1:%.*]] = trunc <4 x i64> [[TMP0]] to <4 x i32>
+; DEFAULT-NEXT:    [[TMP1:%.*]] = trunc nuw nsw <4 x i64> [[TMP0]] to <4 x i32>
 ; DEFAULT-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[TMP1]], i32 3
 ; DEFAULT-NEXT:    store i32 [[TMP2]], ptr [[DST]], align 4
 ; DEFAULT-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
@@ -1521,7 +1521,7 @@ define void @redundant_branch_and_tail_folding(ptr %dst, i1 %c) {
 ; PRED-NEXT:    [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[PRED_STORE_CONTINUE6]] ]
 ; PRED-NEXT:    [[TMP0:%.*]] = icmp ule <4 x i64> [[VEC_IND]], splat (i64 20)
 ; PRED-NEXT:    [[TMP1:%.*]] = add nuw nsw <4 x i64> [[VEC_IND]], splat (i64 1)
-; PRED-NEXT:    [[TMP2:%.*]] = trunc <4 x i64> [[TMP1]] to <4 x i32>
+; PRED-NEXT:    [[TMP2:%.*]] = trunc nuw nsw <4 x i64> [[TMP1]] to <4 x i32>
 ; PRED-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP0]], i32 0
 ; PRED-NEXT:    br i1 [[TMP3]], label %[[PRED_STORE_IF:.*]], label %[[PRED_STORE_CONTINUE:.*]]
 ; PRED:       [[PRED_STORE_IF]]:
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reg-pressure-vmla.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reg-pressure-vmla.ll
index 4c29a3a0d1d01..0a751699a0549 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reg-pressure-vmla.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reg-pressure-vmla.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --filter-out-after "^scalar.ph:" --version 5
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --filter-out-after "^scalar.ph:" --version 5
 ; RUN: opt -mattr=+mve -passes=loop-vectorize < %s -S -o - | FileCheck %s
 
 target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
@@ -49,7 +49,7 @@ define void @fn(i32 noundef %n, ptr %in, ptr %out) #0 {
 ; CHECK-NEXT:    [[TMP10:%.*]] = add nuw nsw <4 x i32> [[TMP9]], [[TMP6]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = add nuw nsw <4 x i32> [[TMP10]], [[TMP8]]
 ; CHECK-NEXT:    [[TMP12:%.*]] = lshr <4 x i32> [[TMP11]], splat (i32 16)
-; CHECK-NEXT:    [[TMP13:%.*]] = trunc <4 x i32> [[TMP12]] to <4 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = trunc nuw <4 x i32> [[TMP12]] to <4 x i8>
 ; CHECK-NEXT:    [[TMP14:%.*]] = mul nuw nsw <4 x i32> [[TMP3]], splat (i32 32767)
 ; CHECK-NEXT:    [[TMP15:%.*]] = mul nuw <4 x i32> [[TMP5]], splat (i32 16762097)
 ; CHECK-NEXT:    [[TMP16:%.*]] = mul nuw <4 x i32> [[TMP7]], splat (i32 16759568)
diff --git a/llvm/test/Transforms/PhaseOrdering/ARM/arm_mult_q15.ll b/llvm/test/Transforms/PhaseOrdering/ARM/arm_mult_q15.ll
index 9032c363eb936..9d613b8fe456d 100644
--- a/llvm/test/Transforms/PhaseOrdering/ARM/arm_mult_q15.ll
+++ b/llvm/test/Transforms/PhaseOrdering/ARM/arm_mult_q15.ll
@@ -41,7 +41,7 @@ define void @arm_mult_q15(ptr %pSrcA, ptr %pSrcB, ptr noalias %pDst, i32 %blockS
 ; CHECK-NEXT:    [[TMP5:%.*]] = mul nsw <8 x i32> [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = ashr <8 x i32> [[TMP5]], splat (i32 15)
 ; CHECK-NEXT:    [[TMP7:%.*]] = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> [[TMP6]], <8 x i32> splat (i32 32767))
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <8 x i32> [[TMP7]] to <8 x i16>
+; CHECK-NEXT:    [[TMP8:%.*]] = trunc nsw <8 x i32> [[TMP7]] to <8 x i16>
 ; CHECK-NEXT:    store <8 x i16> [[TMP8]], ptr [[NEXT_GEP14]], align 2
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]

>From a4ed4881e1f1b5869c70254c23a80cf7938e0197 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 18 Jun 2025 15:07:27 +0100
Subject: [PATCH 2/3] Remove OverflowingOperator

---
 llvm/include/llvm/IR/Operator.h               | 34 +++++--------------
 llvm/lib/Transforms/Vectorize/VPlan.h         | 18 +++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  4 +--
 3 files changed, 19 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index 13272ed2727ec..8344eaec807b3 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -73,9 +73,9 @@ class Operator : public User {
 };
 
 /// Utility class for integer operators which may exhibit overflow - Add, Sub,
-/// Mul, Shl and Trunc. It does not include SDiv, despite that operator having
-/// the potential for overflow.
-class OverflowingOperator : public Operator {
+/// Mul, and Shl. It does not include SDiv, despite that operator having the
+/// potential for overflow.
+class OverflowingBinaryOperator : public Operator {
 public:
   enum {
     AnyWrap        = 0,
@@ -97,6 +97,9 @@ class OverflowingOperator : public Operator {
   }
 
 public:
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value);
+
   /// Test whether this operation is known to never
   /// undergo unsigned overflow, aka the nuw property.
   bool hasNoUnsignedWrap() const {
@@ -128,32 +131,11 @@ class OverflowingOperator : public Operator {
     return I->getOpcode() == Instruction::Add ||
            I->getOpcode() == Instruction::Sub ||
            I->getOpcode() == Instruction::Mul ||
-           I->getOpcode() == Instruction::Shl ||
-           I->getOpcode() == Instruction::Trunc;
+           I->getOpcode() == Instruction::Shl;
   }
   static bool classof(const ConstantExpr *CE) {
     return CE->getOpcode() == Instruction::Add ||
-           CE->getOpcode() == Instruction::Sub ||
-           CE->getOpcode() == Instruction::Trunc;
-  }
-  static bool classof(const Value *V) {
-    return (isa<Instruction>(V) && classof(cast<Instruction>(V))) ||
-           (isa<ConstantExpr>(V) && classof(cast<ConstantExpr>(V)));
-  }
-};
-
-/// The subset of OverflowingOperators that are also BinaryOperators.
-class OverflowingBinaryOperator : public OverflowingOperator {
-public:
-  /// Transparently provide more efficient getOperand methods.
-  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value);
-
-  static bool classof(const Instruction *I) {
-    return OverflowingOperator::classof(I) && isa<BinaryOperator>(I);
-  }
-  static bool classof(const ConstantExpr *CE) {
-    return OverflowingOperator::classof(CE) &&
-           Instruction::isBinaryOp(CE->getOpcode());
+           CE->getOpcode() == Instruction::Sub;
   }
   static bool classof(const Value *V) {
     return (isa<Instruction>(V) && classof(cast<Instruction>(V))) ||
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 8b9a03cb941e9..561485c54b67e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -595,7 +595,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
-    OverflowingOp,
+    WrappingOp,
     DisjointOp,
     PossiblyExactOp,
     GEPOp,
@@ -661,9 +661,9 @@ class VPIRFlags {
     } else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
       OpType = OperationType::DisjointOp;
       DisjointFlags.IsDisjoint = Op->isDisjoint();
-    } else if (auto *Op = dyn_cast<OverflowingOperator>(&I)) {
-      OpType = OperationType::OverflowingOp;
-      WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
+    } else if (isa<OverflowingBinaryOperator, TruncInst>(&I)) {
+      OpType = OperationType::WrappingOp;
+      WrapFlags = {I.hasNoUnsignedWrap(), I.hasNoSignedWrap()};
     } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
       OpType = OperationType::PossiblyExactOp;
       ExactFlags.IsExact = Op->isExact();
@@ -686,7 +686,7 @@ class VPIRFlags {
       : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
   VPIRFlags(WrapFlagsTy WrapFlags)
-      : OpType(OperationType::OverflowingOp), WrapFlags(WrapFlags) {}
+      : OpType(OperationType::WrappingOp), WrapFlags(WrapFlags) {}
 
   VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
@@ -710,7 +710,7 @@ class VPIRFlags {
     // NOTE: This needs to be kept in-sync with
     // Instruction::dropPoisonGeneratingFlags.
     switch (OpType) {
-    case OperationType::OverflowingOp:
+    case OperationType::WrappingOp:
       WrapFlags.HasNUW = false;
       WrapFlags.HasNSW = false;
       break;
@@ -739,7 +739,7 @@ class VPIRFlags {
   /// Apply the IR flags to \p I.
   void applyFlags(Instruction &I) const {
     switch (OpType) {
-    case OperationType::OverflowingOp:
+    case OperationType::WrappingOp:
       I.setHasNoUnsignedWrap(WrapFlags.HasNUW);
       I.setHasNoSignedWrap(WrapFlags.HasNSW);
       break;
@@ -799,13 +799,13 @@ class VPIRFlags {
   }
 
   bool hasNoUnsignedWrap() const {
-    assert(OpType == OperationType::OverflowingOp &&
+    assert(OpType == OperationType::WrappingOp &&
            "recipe doesn't have a NUW flag");
     return WrapFlags.HasNUW;
   }
 
   bool hasNoSignedWrap() const {
-    assert(OpType == OperationType::OverflowingOp &&
+    assert(OpType == OperationType::WrappingOp &&
            "recipe doesn't have a NSW flag");
     return WrapFlags.HasNSW;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 96bb6bb0d37b5..303f8e9f20688 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1627,7 +1627,7 @@ VPIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
 #if !defined(NDEBUG)
 bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
   switch (OpType) {
-  case OperationType::OverflowingOp:
+  case OperationType::WrappingOp:
     return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
            Opcode == Instruction::Mul || Opcode == Instruction::Trunc ||
            Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
@@ -1672,7 +1672,7 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
     if (ExactFlags.IsExact)
       O << " exact";
     break;
-  case OperationType::OverflowingOp:
+  case OperationType::WrappingOp:
     if (WrapFlags.HasNUW)
       O << " nuw";
     if (WrapFlags.HasNSW)

>From 0a7258aabb95a811e450e97308d9e32c905b326b Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 19 Jun 2025 15:02:18 +0100
Subject: [PATCH 3/3] Add separate TruncFlagsTy

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 56 ++++++++++++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 14 ++++-
 2 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 561485c54b67e..da28ef8affc4c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -595,7 +595,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
-    WrappingOp,
+    OverflowingBinOp,
+    Trunc,
     DisjointOp,
     PossiblyExactOp,
     GEPOp,
@@ -612,6 +613,13 @@ class VPIRFlags {
     WrapFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {}
   };
 
+  struct TruncFlagsTy {
+    char HasNUW : 1;
+    char HasNSW : 1;
+
+    TruncFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {}
+  };
+
   struct DisjointFlagsTy {
     char IsDisjoint : 1;
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
@@ -643,6 +651,7 @@ class VPIRFlags {
   union {
     CmpInst::Predicate CmpPredicate;
     WrapFlagsTy WrapFlags;
+    TruncFlagsTy TruncFlags;
     DisjointFlagsTy DisjointFlags;
     ExactFlagsTy ExactFlags;
     GEPNoWrapFlags GEPFlags;
@@ -661,9 +670,12 @@ class VPIRFlags {
     } else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
       OpType = OperationType::DisjointOp;
       DisjointFlags.IsDisjoint = Op->isDisjoint();
-    } else if (isa<OverflowingBinaryOperator, TruncInst>(&I)) {
-      OpType = OperationType::WrappingOp;
-      WrapFlags = {I.hasNoUnsignedWrap(), I.hasNoSignedWrap()};
+    } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+      OpType = OperationType::OverflowingBinOp;
+      WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
+    } else if (auto *Op = dyn_cast<TruncInst>(&I)) {
+      OpType = OperationType::Trunc;
+      TruncFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
     } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
       OpType = OperationType::PossiblyExactOp;
       ExactFlags.IsExact = Op->isExact();
@@ -686,7 +698,7 @@ class VPIRFlags {
       : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
   VPIRFlags(WrapFlagsTy WrapFlags)
-      : OpType(OperationType::WrappingOp), WrapFlags(WrapFlags) {}
+      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
 
   VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
@@ -710,10 +722,14 @@ class VPIRFlags {
     // NOTE: This needs to be kept in-sync with
     // Instruction::dropPoisonGeneratingFlags.
     switch (OpType) {
-    case OperationType::WrappingOp:
+    case OperationType::OverflowingBinOp:
       WrapFlags.HasNUW = false;
       WrapFlags.HasNSW = false;
       break;
+    case OperationType::Trunc:
+      TruncFlags.HasNUW = false;
+      TruncFlags.HasNSW = false;
+      break;
     case OperationType::DisjointOp:
       DisjointFlags.IsDisjoint = false;
       break;
@@ -739,10 +755,14 @@ class VPIRFlags {
   /// Apply the IR flags to \p I.
   void applyFlags(Instruction &I) const {
     switch (OpType) {
-    case OperationType::WrappingOp:
+    case OperationType::OverflowingBinOp:
       I.setHasNoUnsignedWrap(WrapFlags.HasNUW);
       I.setHasNoSignedWrap(WrapFlags.HasNSW);
       break;
+    case OperationType::Trunc:
+      I.setHasNoUnsignedWrap(TruncFlags.HasNUW);
+      I.setHasNoSignedWrap(TruncFlags.HasNSW);
+      break;
     case OperationType::DisjointOp:
       cast<PossiblyDisjointInst>(&I)->setIsDisjoint(DisjointFlags.IsDisjoint);
       break;
@@ -799,15 +819,25 @@ class VPIRFlags {
   }
 
   bool hasNoUnsignedWrap() const {
-    assert(OpType == OperationType::WrappingOp &&
-           "recipe doesn't have a NUW flag");
-    return WrapFlags.HasNUW;
+    switch (OpType) {
+    case OperationType::OverflowingBinOp:
+      return WrapFlags.HasNUW;
+    case OperationType::Trunc:
+      return TruncFlags.HasNUW;
+    default:
+      llvm_unreachable("recipe doesn't have a NUW flag");
+    }
   }
 
   bool hasNoSignedWrap() const {
-    assert(OpType == OperationType::WrappingOp &&
-           "recipe doesn't have a NSW flag");
-    return WrapFlags.HasNSW;
+    switch (OpType) {
+    case OperationType::OverflowingBinOp:
+      return WrapFlags.HasNSW;
+    case OperationType::Trunc:
+      return TruncFlags.HasNSW;
+    default:
+      llvm_unreachable("recipe doesn't have a NSW flag");
+    }
   }
 
   bool isDisjoint() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 303f8e9f20688..3a12e0eb6900f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1627,10 +1627,12 @@ VPIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
 #if !defined(NDEBUG)
 bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
   switch (OpType) {
-  case OperationType::WrappingOp:
+  case OperationType::OverflowingBinOp:
     return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
-           Opcode == Instruction::Mul || Opcode == Instruction::Trunc ||
+           Opcode == Instruction::Mul ||
            Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
+  case OperationType::Trunc:
+    return Opcode == Instruction::Trunc;
   case OperationType::DisjointOp:
     return Opcode == Instruction::Or;
   case OperationType::PossiblyExactOp:
@@ -1672,12 +1674,18 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
     if (ExactFlags.IsExact)
       O << " exact";
     break;
-  case OperationType::WrappingOp:
+  case OperationType::OverflowingBinOp:
     if (WrapFlags.HasNUW)
       O << " nuw";
     if (WrapFlags.HasNSW)
       O << " nsw";
     break;
+  case OperationType::Trunc:
+    if (TruncFlags.HasNUW)
+      O << " nuw";
+    if (TruncFlags.HasNSW)
+      O << " nsw";
+    break;
   case OperationType::FPMathOp:
     getFastMathFlags().print(O);
     break;



More information about the llvm-commits mailing list