[Mlir-commits] [mlir] 6fa87ec - [ADT] Deprecate is_splat and replace all uses with all_equal

Jakub Kuderski llvmlistbot at llvm.org
Tue Aug 23 08:42:12 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-23T11:36:27-04:00
New Revision: 6fa87ec10fce41709529abe5b94a6b18ad2062c7

URL: https://github.com/llvm/llvm-project/commit/6fa87ec10fce41709529abe5b94a6b18ad2062c7
DIFF: https://github.com/llvm/llvm-project/commit/6fa87ec10fce41709529abe5b94a6b18ad2062c7.diff

LOG: [ADT] Deprecate is_splat and replace all uses with all_equal

See the discussion thread for more details:
https://discourse.llvm.org/t/adt-is-splat-and-empty-ranges/64692

Reviewed By: dblaikie

Differential Revision: https://reviews.llvm.org/D132335

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/IR/Instructions.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/Scalar/NewGVN.cpp
    llvm/unittests/ADT/STLExtrasTest.cpp
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/mlir-tblgen/predicate.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 4784414995914..0fde95878a2a0 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1795,12 +1795,20 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
 }
 
 /// Returns true if Range consists of the same value repeated multiple times.
-template <typename R> bool is_splat(R &&Range) {
+template <typename R>
+LLVM_DEPRECATED(
+    "Use 'all_equal(Range)' or '!empty(Range) && all_equal(Range)' instead.",
+    "all_equal")
+bool is_splat(R &&Range) {
   return !llvm::empty(Range) && all_equal(Range);
 }
 
 /// Returns true if Values consists of the same value repeated multiple times.
-template <typename T> bool is_splat(std::initializer_list<T> Values) {
+template <typename T>
+LLVM_DEPRECATED(
+    "Use 'all_equal(Values)' or '!empty(Values) && all_equal(Values)' instead.",
+    "all_equal")
+bool is_splat(std::initializer_list<T> Values) {
   return is_splat<std::initializer_list<T>>(std::move(Values));
 }
 

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d5f90d6fdc680..8b08d7c9228fd 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4997,7 +4997,7 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1,
   // value type is same as the input vectors' type.
   if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
     if (Q.isUndefValue(Op1) && RetTy == InVecTy &&
-        is_splat(OpShuf->getShuffleMask()))
+        all_equal(OpShuf->getShuffleMask()))
       return Op0;
 
   // All remaining transformation depend on the value of the mask, which is

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index c4795a80ead24..57a373056b2b9 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -398,7 +398,7 @@ bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
     // FIXME: We can safely allow undefs here. If Index was specified, we will
     //        check that the mask elt is defined at the required index.
-    if (!is_splat(Shuf->getShuffleMask()))
+    if (!all_equal(Shuf->getShuffleMask()))
       return false;
 
     // Match any index.
@@ -478,7 +478,7 @@ bool llvm::widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
     if (SliceFront < 0) {
       // Negative values (undef or other "sentinel" values) must be equal across
       // the entire slice.
-      if (!is_splat(MaskSlice))
+      if (!all_equal(MaskSlice))
         return false;
       ScaledMask.push_back(SliceFront);
     } else {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 11fb6fe1c165e..07fb54e95af08 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23713,7 +23713,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
     // demanded elements analysis. It is further limited to not change a splat
     // of an inserted scalar because that may be optimized better by
     // load-folding or other target-specific behaviors.
-    if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
+    if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
       // binop (splat X), (splat C) --> splat (binop X, C)
@@ -23722,7 +23722,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
                                   Shuf0->getMask());
     }
-    if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
+    if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
       // binop (splat C), (splat X) --> splat (binop C, X)

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d14a3a11ef74a..f244aafb1da8c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3287,7 +3287,7 @@ void SelectionDAGBuilder::visitSelect(const User &I) {
     Flags.copyFMF(*FPOp);
 
   // Min/max matching is only viable if all output VTs are the same.
-  if (is_splat(ValueVTs)) {
+  if (all_equal(ValueVTs)) {
     EVT VT = ValueVTs[0];
     LLVMContext &Ctx = *DAG.getContext();
     auto &TLI = DAG.getTargetLoweringInfo();

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 43518c436e3e2..1e06851a35573 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2061,7 +2061,7 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
       return false;
 
   if (isa<ScalableVectorType>(V1->getType()))
-    if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !is_splat(Mask))
+    if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !all_equal(Mask))
       return false;
 
   return true;
@@ -2152,7 +2152,7 @@ Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef<int> Mask,
                                                           Type *ResultTy) {
   Type *Int32Ty = Type::getInt32Ty(ResultTy->getContext());
   if (isa<ScalableVectorType>(ResultTy)) {
-    assert(is_splat(Mask) && "Unexpected shuffle");
+    assert(all_equal(Mask) && "Unexpected shuffle");
     Type *VecTy = VectorType::get(Int32Ty, Mask.size(), true);
     if (Mask[0] == 0)
       return Constant::getNullValue(VecTy);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 219bb1dfbac4b..95720611de176 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12890,7 +12890,7 @@ static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
 
 static bool isSplatShuffle(Value *V) {
   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
-    return is_splat(Shuf->getShuffleMask());
+    return all_equal(Shuf->getShuffleMask());
   return false;
 }
 
@@ -20827,7 +20827,7 @@ bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
   // All non aggregate members of the type must have the same type
   SmallVector<EVT> ValueVTs;
   ComputeValueVTs(*this, DL, Ty, ValueVTs);
-  return is_splat(ValueVTs);
+  return all_equal(ValueVTs);
 }
 
 bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index a9a930555b3c6..d68169db61cf2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -726,7 +726,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
                                        InstCombiner::BuilderTy &Builder) {
   auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
   if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
-      is_splat(Shuf->getShuffleMask()) &&
+      all_equal(Shuf->getShuffleMask()) &&
       Shuf->getType() == Shuf->getOperand(0)->getType()) {
     // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
     // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 46e4f6b7edd33..fd452574c9b07 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3141,7 +3141,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
   ArrayRef<int> Mask;
   if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
     // Check whether every element of Mask is the same constant
-    if (is_splat(Mask)) {
+    if (all_equal(Mask)) {
       auto *VecTy = cast<VectorType>(SrcType);
       auto *EltTy = cast<IntegerType>(VecTy->getElementType());
       if (C->isSplat(EltTy->getBitWidth())) {

diff  --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 3dfd0eb799aef..20df0a80a27bf 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -3166,7 +3166,7 @@ bool NewGVN::singleReachablePHIPath(
       make_filter_range(MP->operands(), ReachableOperandPred);
   SmallVector<const Value *, 32> OperandList;
   llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList));
-  bool Okay = is_splat(OperandList);
+  bool Okay = all_equal(OperandList);
   if (Okay)
     return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]),
                                   Second);
@@ -3261,7 +3261,7 @@ void NewGVN::verifyMemoryCongruency() const {
                        const MemoryDef *MD = cast<MemoryDef>(U);
                        return ValueToClass.lookup(MD->getMemoryInst());
                      });
-      assert(is_splat(PhiOpClasses) &&
+      assert(all_equal(PhiOpClasses) &&
              "All MemoryPhi arguments should be in the same class");
     }
   }

diff  --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 9b0af51658c24..0d931152c5db2 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -611,28 +611,6 @@ TEST(STLExtrasTest, AllEqualInitializerList) {
   EXPECT_TRUE(all_equal({1, 1, 1}));
 }
 
-TEST(STLExtrasTest, IsSplat) {
-  std::vector<int> V;
-  EXPECT_FALSE(is_splat(V));
-
-  V.push_back(1);
-  EXPECT_TRUE(is_splat(V));
-
-  V.push_back(1);
-  V.push_back(1);
-  EXPECT_TRUE(is_splat(V));
-
-  V.push_back(2);
-  EXPECT_FALSE(is_splat(V));
-}
-
-TEST(STLExtrasTest, IsSplatInitializerList) {
-  EXPECT_TRUE(is_splat({1}));
-  EXPECT_TRUE(is_splat({1, 1}));
-  EXPECT_FALSE(is_splat({1, 2}));
-  EXPECT_TRUE(is_splat({1, 1, 1}));
-}
-
 TEST(STLExtrasTest, to_address) {
   int *V1 = new int;
   EXPECT_EQ(V1, to_address(V1));

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index cbacb90b08114..65d537ecd68ef 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2432,7 +2432,7 @@ class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
 // 1) all operands involved are of shaped type and
 // 2) the indices are not out of range.
 class TCopVTEtAreSameAt<list<int> indices> : CPred<
-  "::llvm::is_splat(::llvm::map_range("
+  "::llvm::all_equal(::llvm::map_range("
       "::mlir::ArrayRef<unsigned>({" # !interleave(indices, ", ") # "}), "
       "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
       "}))">;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index c3669724ba8e5..05421bf0a0ac0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -77,8 +77,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     auto getOperandElementType = [](OpOperand *operand) {
       return operand->get().getType().cast<ShapedType>().getElementType();
     };
-    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
-                                        getOperandElementType)))
+    if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
+                                         getOperandElementType)))
       return failure();
 
     // We can only handle the case where we have int/float elements.

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index ef745a6a1bf3a..de42a602b696c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2871,9 +2871,9 @@ LogicalResult spirv::IAddCarryOp::verify() {
   if (resultType.getNumElements() != 2)
     return emitOpError("expected result struct type containing two members");
 
-  if (!llvm::is_splat({operand1().getType(), operand2().getType(),
-                       resultType.getElementType(0),
-                       resultType.getElementType(1)}))
+  if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+                        resultType.getElementType(0),
+                        resultType.getElementType(1)}))
     return emitOpError(
         "expected all operand types and struct member types are the same");
 
@@ -2920,9 +2920,9 @@ LogicalResult spirv::ISubBorrowOp::verify() {
   if (resultType.getNumElements() != 2)
     return emitOpError("expected result struct type containing two members");
 
-  if (!llvm::is_splat({operand1().getType(), operand2().getType(),
-                       resultType.getElementType(0),
-                       resultType.getElementType(1)}))
+  if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+                        resultType.getElementType(0),
+                        resultType.getElementType(1)}))
     return emitOpError(
         "expected all operand types and struct member types are the same");
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 313ec3d40c527..8c6eae1dfb73d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1269,7 +1269,7 @@ struct ReorderElementwiseOpsOnTranspose final
     // This is an elementwise op, so all transposed operands should have the
     // same type. We need to additionally check that all transposes uses the
     // same map.
-    if (!llvm::is_splat(transposeMaps))
+    if (!llvm::all_equal(transposeMaps))
       return rewriter.notifyMatchFailure(op, "different transpose map");
 
     SmallVector<Value, 4> srcValues;

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index a67e41fe205a0..07d9b4ffb7ba6 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -102,7 +102,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
 }
 
 // CHECK-LABEL: OpJAdaptor::verify
-// CHECK:      ::llvm::is_splat(::llvm::map_range(
+// CHECK:      ::llvm::all_equal(::llvm::map_range(
 // CHECK-SAME:   ::mlir::ArrayRef<unsigned>({0, 2, 3}),
 // CHECK-SAME:   [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
 // CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8eebc8627277a..ec0e79ebdd834 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -41,7 +41,7 @@ static void collectAllDefs(StringRef selectedDialect,
   if (selectedDialect.empty()) {
     // If a dialect was not specified, ensure that all found defs belong to the
     // same dialect.
-    if (!llvm::is_splat(llvm::map_range(
+    if (!llvm::all_equal(llvm::map_range(
             defs, [](const auto &def) { return def.getDialect(); }))) {
       llvm::PrintFatalError("defs belonging to more than one dialect. Must "
                             "select one via '--(attr|type)defs-dialect'");


        


More information about the Mlir-commits mailing list