[Mlir-commits] [mlir] 6fa87ec - [ADT] Deprecate is_splat and replace all uses with all_equal
Jakub Kuderski
llvmlistbot at llvm.org
Tue Aug 23 08:42:12 PDT 2022
Author: Jakub Kuderski
Date: 2022-08-23T11:36:27-04:00
New Revision: 6fa87ec10fce41709529abe5b94a6b18ad2062c7
URL: https://github.com/llvm/llvm-project/commit/6fa87ec10fce41709529abe5b94a6b18ad2062c7
DIFF: https://github.com/llvm/llvm-project/commit/6fa87ec10fce41709529abe5b94a6b18ad2062c7.diff
LOG: [ADT] Deprecate is_splat and replace all uses with all_equal
See the discussion thread for more details:
https://discourse.llvm.org/t/adt-is-splat-and-empty-ranges/64692
Reviewed By: dblaikie
Differential Revision: https://reviews.llvm.org/D132335
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
llvm/lib/Analysis/InstructionSimplify.cpp
llvm/lib/Analysis/VectorUtils.cpp
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/IR/Instructions.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/Scalar/NewGVN.cpp
llvm/unittests/ADT/STLExtrasTest.cpp
mlir/include/mlir/IR/OpBase.td
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/mlir-tblgen/predicate.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 4784414995914..0fde95878a2a0 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1795,12 +1795,20 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
}
/// Returns true if Range consists of the same value repeated multiple times.
-template <typename R> bool is_splat(R &&Range) {
+template <typename R>
+LLVM_DEPRECATED(
+ "Use 'all_equal(Range)' or '!empty(Range) && all_equal(Range)' instead.",
+ "all_equal")
+bool is_splat(R &&Range) {
return !llvm::empty(Range) && all_equal(Range);
}
/// Returns true if Values consists of the same value repeated multiple times.
-template <typename T> bool is_splat(std::initializer_list<T> Values) {
+template <typename T>
+LLVM_DEPRECATED(
+ "Use 'all_equal(Values)' or '!empty(Values) && all_equal(Values)' instead.",
+ "all_equal")
+bool is_splat(std::initializer_list<T> Values) {
return is_splat<std::initializer_list<T>>(std::move(Values));
}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d5f90d6fdc680..8b08d7c9228fd 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4997,7 +4997,7 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1,
// value type is same as the input vectors' type.
if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
if (Q.isUndefValue(Op1) && RetTy == InVecTy &&
- is_splat(OpShuf->getShuffleMask()))
+ all_equal(OpShuf->getShuffleMask()))
return Op0;
// All remaining transformation depend on the value of the mask, which is
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index c4795a80ead24..57a373056b2b9 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -398,7 +398,7 @@ bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
// FIXME: We can safely allow undefs here. If Index was specified, we will
// check that the mask elt is defined at the required index.
- if (!is_splat(Shuf->getShuffleMask()))
+ if (!all_equal(Shuf->getShuffleMask()))
return false;
// Match any index.
@@ -478,7 +478,7 @@ bool llvm::widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
if (SliceFront < 0) {
// Negative values (undef or other "sentinel" values) must be equal across
// the entire slice.
- if (!is_splat(MaskSlice))
+ if (!all_equal(MaskSlice))
return false;
ScaledMask.push_back(SliceFront);
} else {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 11fb6fe1c165e..07fb54e95af08 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23713,7 +23713,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
// demanded elements analysis. It is further limited to not change a splat
// of an inserted scalar because that may be optimized better by
// load-folding or other target-specific behaviors.
- if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
+ if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
// binop (splat X), (splat C) --> splat (binop X, C)
@@ -23722,7 +23722,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
Shuf0->getMask());
}
- if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
+ if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
// binop (splat C), (splat X) --> splat (binop C, X)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d14a3a11ef74a..f244aafb1da8c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3287,7 +3287,7 @@ void SelectionDAGBuilder::visitSelect(const User &I) {
Flags.copyFMF(*FPOp);
// Min/max matching is only viable if all output VTs are the same.
- if (is_splat(ValueVTs)) {
+ if (all_equal(ValueVTs)) {
EVT VT = ValueVTs[0];
LLVMContext &Ctx = *DAG.getContext();
auto &TLI = DAG.getTargetLoweringInfo();
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 43518c436e3e2..1e06851a35573 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2061,7 +2061,7 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
return false;
if (isa<ScalableVectorType>(V1->getType()))
- if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !is_splat(Mask))
+ if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !all_equal(Mask))
return false;
return true;
@@ -2152,7 +2152,7 @@ Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef<int> Mask,
Type *ResultTy) {
Type *Int32Ty = Type::getInt32Ty(ResultTy->getContext());
if (isa<ScalableVectorType>(ResultTy)) {
- assert(is_splat(Mask) && "Unexpected shuffle");
+ assert(all_equal(Mask) && "Unexpected shuffle");
Type *VecTy = VectorType::get(Int32Ty, Mask.size(), true);
if (Mask[0] == 0)
return Constant::getNullValue(VecTy);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 219bb1dfbac4b..95720611de176 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12890,7 +12890,7 @@ static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
static bool isSplatShuffle(Value *V) {
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
- return is_splat(Shuf->getShuffleMask());
+ return all_equal(Shuf->getShuffleMask());
return false;
}
@@ -20827,7 +20827,7 @@ bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
// All non aggregate members of the type must have the same type
SmallVector<EVT> ValueVTs;
ComputeValueVTs(*this, DL, Ty, ValueVTs);
- return is_splat(ValueVTs);
+ return all_equal(ValueVTs);
}
bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index a9a930555b3c6..d68169db61cf2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -726,7 +726,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
InstCombiner::BuilderTy &Builder) {
auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
- is_splat(Shuf->getShuffleMask()) &&
+ all_equal(Shuf->getShuffleMask()) &&
Shuf->getType() == Shuf->getOperand(0)->getType()) {
// trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
// trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 46e4f6b7edd33..fd452574c9b07 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3141,7 +3141,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
ArrayRef<int> Mask;
if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
// Check whether every element of Mask is the same constant
- if (is_splat(Mask)) {
+ if (all_equal(Mask)) {
auto *VecTy = cast<VectorType>(SrcType);
auto *EltTy = cast<IntegerType>(VecTy->getElementType());
if (C->isSplat(EltTy->getBitWidth())) {
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 3dfd0eb799aef..20df0a80a27bf 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -3166,7 +3166,7 @@ bool NewGVN::singleReachablePHIPath(
make_filter_range(MP->operands(), ReachableOperandPred);
SmallVector<const Value *, 32> OperandList;
llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList));
- bool Okay = is_splat(OperandList);
+ bool Okay = all_equal(OperandList);
if (Okay)
return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]),
Second);
@@ -3261,7 +3261,7 @@ void NewGVN::verifyMemoryCongruency() const {
const MemoryDef *MD = cast<MemoryDef>(U);
return ValueToClass.lookup(MD->getMemoryInst());
});
- assert(is_splat(PhiOpClasses) &&
+ assert(all_equal(PhiOpClasses) &&
"All MemoryPhi arguments should be in the same class");
}
}
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 9b0af51658c24..0d931152c5db2 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -611,28 +611,6 @@ TEST(STLExtrasTest, AllEqualInitializerList) {
EXPECT_TRUE(all_equal({1, 1, 1}));
}
-TEST(STLExtrasTest, IsSplat) {
- std::vector<int> V;
- EXPECT_FALSE(is_splat(V));
-
- V.push_back(1);
- EXPECT_TRUE(is_splat(V));
-
- V.push_back(1);
- V.push_back(1);
- EXPECT_TRUE(is_splat(V));
-
- V.push_back(2);
- EXPECT_FALSE(is_splat(V));
-}
-
-TEST(STLExtrasTest, IsSplatInitializerList) {
- EXPECT_TRUE(is_splat({1}));
- EXPECT_TRUE(is_splat({1, 1}));
- EXPECT_FALSE(is_splat({1, 2}));
- EXPECT_TRUE(is_splat({1, 1, 1}));
-}
-
TEST(STLExtrasTest, to_address) {
int *V1 = new int;
EXPECT_EQ(V1, to_address(V1));
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index cbacb90b08114..65d537ecd68ef 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2432,7 +2432,7 @@ class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
// 1) all operands involved are of shaped type and
// 2) the indices are not out of range.
class TCopVTEtAreSameAt<list<int> indices> : CPred<
- "::llvm::is_splat(::llvm::map_range("
+ "::llvm::all_equal(::llvm::map_range("
"::mlir::ArrayRef<unsigned>({" # !interleave(indices, ", ") # "}), "
"[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
"}))">;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index c3669724ba8e5..05421bf0a0ac0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -77,8 +77,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
auto getOperandElementType = [](OpOperand *operand) {
return operand->get().getType().cast<ShapedType>().getElementType();
};
- if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
- getOperandElementType)))
+ if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
+ getOperandElementType)))
return failure();
// We can only handle the case where we have int/float elements.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index ef745a6a1bf3a..de42a602b696c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2871,9 +2871,9 @@ LogicalResult spirv::IAddCarryOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");
- if (!llvm::is_splat({operand1().getType(), operand2().getType(),
- resultType.getElementType(0),
- resultType.getElementType(1)}))
+ if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+ resultType.getElementType(0),
+ resultType.getElementType(1)}))
return emitOpError(
"expected all operand types and struct member types are the same");
@@ -2920,9 +2920,9 @@ LogicalResult spirv::ISubBorrowOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");
- if (!llvm::is_splat({operand1().getType(), operand2().getType(),
- resultType.getElementType(0),
- resultType.getElementType(1)}))
+ if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+ resultType.getElementType(0),
+ resultType.getElementType(1)}))
return emitOpError(
"expected all operand types and struct member types are the same");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 313ec3d40c527..8c6eae1dfb73d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1269,7 +1269,7 @@ struct ReorderElementwiseOpsOnTranspose final
// This is an elementwise op, so all transposed operands should have the
// same type. We need to additionally check that all transposes uses the
// same map.
- if (!llvm::is_splat(transposeMaps))
+ if (!llvm::all_equal(transposeMaps))
return rewriter.notifyMatchFailure(op, "different transpose map");
SmallVector<Value, 4> srcValues;
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index a67e41fe205a0..07d9b4ffb7ba6 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -102,7 +102,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
}
// CHECK-LABEL: OpJAdaptor::verify
-// CHECK: ::llvm::is_splat(::llvm::map_range(
+// CHECK: ::llvm::all_equal(::llvm::map_range(
// CHECK-SAME: ::mlir::ArrayRef<unsigned>({0, 2, 3}),
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8eebc8627277a..ec0e79ebdd834 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -41,7 +41,7 @@ static void collectAllDefs(StringRef selectedDialect,
if (selectedDialect.empty()) {
// If a dialect was not specified, ensure that all found defs belong to the
// same dialect.
- if (!llvm::is_splat(llvm::map_range(
+ if (!llvm::all_equal(llvm::map_range(
defs, [](const auto &def) { return def.getDialect(); }))) {
llvm::PrintFatalError("defs belonging to more than one dialect. Must "
"select one via '--(attr|type)defs-dialect'");
More information about the Mlir-commits mailing list