[Mlir-commits] [llvm] [mlir] [llvm] Add `getSingleElement` helper and use in MLIR (PR #131460)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 16 03:25:14 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/131460
>From 44aa2aa100b4c081f128aeb8b8af42e8bc022564 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 16 Mar 2025 11:21:49 +0100
Subject: [PATCH 1/2] [llvm][ADT] Add `getSingleElement` helper
---
llvm/include/llvm/ADT/STLExtras.h | 8 ++++++++
llvm/unittests/ADT/STLExtrasTest.cpp | 22 ++++++++++++++++++++++
2 files changed, 30 insertions(+)
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 78b7e94c2b3a1..dc0443c9244be 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
return B != E && std::next(B) == E;
}
+/// Asserts that \p C has exactly one element and returns it (by reference if
+/// the iterator yields one; it must not outlive a temporary \p C).
+template <typename ContainerTy>
+decltype(auto) getSingleElement(ContainerTy &&C) {
+ assert(hasSingleElement(C) && "expected container with single element");
+ return *adl_begin(C);
+}
+
/// Return a range covering \p RangeOrContainer with the first N elements
/// excluded.
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index dbb094b0a3088..df8c0a4e4819b 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1016,6 +1016,28 @@ TEST(STLExtrasTest, hasSingleElement) {
EXPECT_FALSE(hasSingleElement(S));
}
+TEST(STLExtrasTest, getSingleElement) {
+ // Note: The failure case (an assertion) is not covered here; it would need a
+ // death test, which only works in builds with assertions enabled.
+ const std::vector<int> V1 = {7};
+ EXPECT_EQ(getSingleElement(V1), 7);
+
+ std::vector<int> V2 = {8};
+ EXPECT_EQ(getSingleElement(V2), 8);
+
+ SmallVector<int> V3 {9};
+ EXPECT_EQ(getSingleElement(V3), 9);
+
+ std::list<int> L1 = {10};
+ EXPECT_EQ(getSingleElement(L1), 10);
+
+ // Make sure that we use the `begin`/`end` functions from `some_namespace`,
+ // using ADL.
+ some_namespace::some_struct S;
+ S.data = V2;
+ EXPECT_EQ(getSingleElement(S), 8);
+}
+
TEST(STLExtrasTest, hasNItems) {
const std::list<int> V0 = {}, V1 = {1}, V2 = {1, 2};
const std::list<int> V3 = {1, 3, 5};
>From 62173153f73910cdb2023700badeb27d5633c724 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 15 Mar 2025 16:45:53 +0100
Subject: [PATCH 2/2] [mlir] Use `llvm::getSingleElement` helper in MLIR
---
mlir/include/mlir/Dialect/CommonFolders.h | 6 ++--
mlir/lib/Analysis/SliceAnalysis.cpp | 2 +-
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 +-
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 6 ++--
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 29 +++++++++----------
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 6 ++--
.../Dialect/Affine/Transforms/LoopFusion.cpp | 3 +-
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 6 ++--
.../SubsetInsertionOpInterfaceImpl.cpp | 13 ++++-----
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
.../Quant/Transforms/StripFuncQuantTypes.cpp | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Transforms/StructuralTypeConversions.cpp | 15 ++++------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 ++--
.../BufferizableOpInterfaceImpl.cpp | 4 +--
.../Transforms/SparseIterationToScf.cpp | 12 ++------
.../Transforms/SparseTensorCodegen.cpp | 16 ++++------
.../Transforms/Sparsification.cpp | 5 ++--
.../Transforms/Utils/SparseTensorIterator.cpp | 13 +++------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 +-
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp | 3 +-
mlir/test/lib/Analysis/TestCFGLoopInfo.cpp | 2 +-
22 files changed, 59 insertions(+), 102 deletions(-)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6f497a259262a..b5a12426aff80 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "unary op takes one operands");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
@@ -268,8 +267,7 @@ template <
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
CalculationT &&calculate) {
- assert(operands.size() == 1 && "Cast op takes one operand");
- if (!operands[0])
+ if (!llvm::getSingleElement(operands))
return {};
static_assert(
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8803ba994b2c1..e01cb3a080b5c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
- parentOp->getRegion(0).getBlocks().size() == 1);
+ llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1f2781aa82114..9c4dfa27b1447 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Type srcType = adaptor.getOperands().front().getType();
+ Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 1b0f023527891..df2da138d3b52 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() == 1);
- Value cst = adaptor.getOperands().front();
+ Value cst = llvm::getSingleElement(adaptor.getOperands());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
"splat is not a composite construct");
}
- assert(cc.getConstituents().size() == 1);
- scalar = cc.getConstituents().front();
+ scalar = llvm::getSingleElement(cc.getConstituents());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b0884d321bc8a..33391995885a4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> dynDims, dynDevice;
for (auto dim : adaptor.getDimsDynamic()) {
// type conversion should be 1:1 for ints
- assert(dim.size() == 1);
- dynDims.emplace_back(dim[0]);
+ dynDims.emplace_back(llvm::getSingleElement(dim));
}
// same for device
for (auto device : adaptor.getDeviceDynamic()) {
- assert(device.size() == 1);
- dynDevice.emplace_back(device[0]);
+ dynDevice.emplace_back(llvm::getSingleElement(device));
}
// To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
typeConverter.addConversion([](Type type) { return type; });
// convert mesh::ShardingType to a tuple of RankedTensorTypes
- typeConverter.addConversion(
- [](ShardingType type,
- SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- auto i16 = IntegerType::get(type.getContext(), 16);
- auto i64 = IntegerType::get(type.getContext(), 64);
- std::array<int64_t, 2> shp = {ShapedType::kDynamic,
- ShapedType::kDynamic};
- results.emplace_back(RankedTensorType::get(shp, i16));
- results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
- results.emplace_back(RankedTensorType::get(shp, i64));
- return success();
- });
+ typeConverter.addConversion([](ShardingType type,
+ SmallVectorImpl<Type> &results)
+ -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return success();
+ });
// To 'extract' components, a UnrealizedConversionCastOp is expected
// to define the input
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..9c5b9e82cd5e0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
}
applyOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
}
minMaxOp->erase();
- assert(foldResults.size() == 1 && "expected 1 folded result");
- return foldResults.front();
+ return llvm::getSingleElement(foldResults);
}
OpFoldResult
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bcba17bb21544..4b4eb9ce37b4c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
SmallVector<Operation *, 2> sibLoadOpInsts;
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
// Currently findSiblingNodeToFuse searches for siblings with one load.
- assert(sibLoadOpInsts.size() == 1);
- Operation *sibLoadOpInst = sibLoadOpInsts[0];
+ Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
// Gather 'dstNode' load ops to 'memref'.
SmallVector<Operation *, 2> dstLoadOpInsts;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 71c6acba32d2e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes,
AffineForOp target) {
SmallVector<AffineForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 6fcfa05468eea..55a09622644ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
linalg::CopyOp> {
OpOperand &getSourceOperand(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getInputs().size() == 1 && "expected single input");
- return copyOp.getInputsMutable()[0];
+ return llvm::getSingleElement(copyOp.getInputsMutable());
}
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+ return equivalenceFn(candidate,
+ llvm::getSingleElement(copyOp.getOutputs()));
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return copyOp.getOutputs()[0];
+ return llvm::getSingleElement(copyOp.getOutputs());
}
SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
- assert(copyOp.getOutputs().size() == 1 && "expected single output");
- return {copyOp.getOutputs()[0]};
+ return {llvm::getSingleElement(copyOp.getOutputs())};
}
};
} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
return op->getNextNode() == op->getBlock()->getTerminator() &&
- op->getParentRegion()->getBlocks().size() == 1;
+ llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 71b88d1be1b05..de834fed90e42 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
static Value materializeConversion(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
- assert(inputs.size() == 1);
- return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+ return builder.create<quant::StorageCastOp>(loc, type,
+ llvm::getSingleElement(inputs));
}
public:
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e9d7dc1b847c6..ee46f9c97268b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
static bool doesNotAliasExternalValue(Value value, Region *region,
ValueRange exceptions,
const OneShotAnalysisState &state) {
- assert(region->getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(region->getBlocks()) &&
"expected region with single block");
bool result = true;
state.applyOnAliases(value, [&](Value alias) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c0589044c26ec..40d2e254fb7dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
// CRTP
// A base class that takes care of 1:N type conversion, which maps the converted
// op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(
- op.getLoc(), getSingleValue(adaptor.getLowerBound()),
- getSingleValue(adaptor.getUpperBound()),
- getSingleValue(adaptor.getStep()),
+ op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+ llvm::getSingleElement(adaptor.getUpperBound()),
+ llvm::getSingleElement(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));
// Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
TypeRange dstTypes) const {
IfOp newOp = rewriter.create<IfOp>(
- op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+ op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+ true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 19335255fd492..e9471c1dbd0b7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
scf::ForOp target) {
SmallVector<scf::ForOp, 8> res;
- for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
- assert(loops.size() == 1);
- res.push_back(loops[0]);
- }
+ for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+ res.push_back(llvm::getSingleElement(loops));
return res;
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 66a2e45001781..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -38,7 +38,7 @@ struct AssumingOpInterface
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
// TODO: Support multiple blocks.
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"expected exactly 1 block");
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
- assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
auto yieldOp = cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 9e9fea76416b9..948ba60ac0bbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -12,12 +12,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
// Construct the iteration space.
SparseIterationSpace space(loc, rewriter,
- getSingleValue(adaptor.getTensor()), 0,
+ llvm::getSingleElement(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());
SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
- Value valBuf =
- rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+ Value valBuf = rewriter.create<ToValuesOp>(
+ loc, llvm::getSingleElement(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20d46f7ca00c5..6a66ad24a87b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
- assert(values.size() == 1 && "expected single value");
- return values.front();
-}
-
/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
op.getTensor().getType());
- Value values = getSingleValue(adaptor.getValues());
- Value filled = getSingleValue(adaptor.getFilled());
- Value added = getSingleValue(adaptor.getAdded());
- Value count = getSingleValue(adaptor.getCount());
+ Value values = llvm::getSingleElement(adaptor.getValues());
+ Value filled = llvm::getSingleElement(adaptor.getFilled());
+ Value added = llvm::getSingleElement(adaptor.getAdded());
+ Value count = llvm::getSingleElement(adaptor.getCount());
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();
@@ -1041,7 +1035,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
SmallVector<Value> params = llvm::to_vector(desc.getFields());
SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
params.append(flatIndices.begin(), flatIndices.end());
- params.push_back(getSingleValue(adaptor.getScalar()));
+ params.push_back(llvm::getSingleElement(adaptor.getScalar()));
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index bf12dc8ae05cc..badcc583bbca2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -521,9 +521,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value ptr = genSubscript(env, builder, t, args);
if (llvm::isa<TensorType>(ptr.getType())) {
assert(env.options().sparseEmitStrategy ==
- SparseEmitStrategy::kSparseIterator &&
- args.size() == 1);
- return builder.create<ExtractValOp>(loc, ptr, args.front());
+ SparseEmitStrategy::kSparseIterator);
+ return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
}
return builder.create<memref::LoadOp>(loc, ptr, args);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index ef95fcc84bd90..aad5e97ed14ab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -1106,9 +1106,7 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
return {notLegit};
});
-
- assert(r.size() == 1);
- return r.front();
+ return llvm::getSingleElement(r);
}
Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
@@ -1120,8 +1118,7 @@ Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
// crd < size
return {CMPI(ult, crd, size)};
});
- assert(r.size() == 1);
- return r.front();
+ return llvm::getSingleElement(r);
}
ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
@@ -1145,7 +1142,6 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
/*beforeBuilder=*/
[this](OpBuilder &b, Location l, ValueRange ivs) {
ValueRange isFirst = linkNewScope(ivs);
- assert(isFirst.size() == 1);
scf::ValueVector cont =
genWhenInBound(b, l, *wrap, C_FALSE,
[this, isFirst](OpBuilder &b, Location l,
@@ -1155,7 +1151,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
genCrdNotLegitPredicate(b, l, wrapCrd);
Value crd = fromWrapCrd(b, l, wrapCrd);
Value ret = ANDI(CMPI(ult, crd, size), notLegit);
- ret = ORI(ret, isFirst.front());
+ ret = ORI(ret, llvm::getSingleElement(isFirst));
return {ret};
});
b.create<scf::ConditionOp>(l, cont.front(), ivs);
@@ -1200,8 +1196,7 @@ Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
// crd < size
return {CMPI(ult, crd, subSect.subSectSz)};
});
- assert(r.size() == 1);
- return r.front();
+ return llvm::getSingleElement(r);
}
Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index ad6d4532bfd4a..7fb5b117676a9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -833,8 +833,7 @@ makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
moduleTranslation, &phis)))
return llvm::createStringError(
"failed to inline `combiner` region of `omp.declare_reduction`");
- assert(phis.size() == 1);
- result = phis[0];
+ result = llvm::getSingleElement(phis);
return builder.saveIP();
};
return gen;
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 964d94c9c0a46..589c21105af10 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -388,8 +388,7 @@ Value CodeGen::genSingleExpr(const ast::Expr *expr) {
.Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
[&](auto derivedNode) {
SmallVector<Value> results = this->genExprImpl(derivedNode);
- assert(results.size() == 1 && "expected single expression result");
- return results[0];
+ return llvm::getSingleElement(results);
});
}
diff --git a/mlir/test/lib/Analysis/TestCFGLoopInfo.cpp b/mlir/test/lib/Analysis/TestCFGLoopInfo.cpp
index 4f4406cbb5bfa..7535994955f57 100644
--- a/mlir/test/lib/Analysis/TestCFGLoopInfo.cpp
+++ b/mlir/test/lib/Analysis/TestCFGLoopInfo.cpp
@@ -53,7 +53,7 @@ void TestCFGLoopInfo::runOnOperation() {
}
llvm::errs() << "\n";
- if (region.getBlocks().size() == 1) {
+ if (llvm::hasSingleElement(region.getBlocks())) {
llvm::errs() << "no loops\n";
return;
}
More information about the Mlir-commits
mailing list