[Mlir-commits] [llvm] [mlir] [llvm] Add `getSingleElement` helper and use in MLIR (PR #131460)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 15 08:48:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-flang-openmp

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds a new helper function: `getSingleElement`.

This function asserts that the container has a single element and then returns that element. This helper function is useful during 1:N dialect conversions, where certain `ValueRange`s (returned from the adaptor) are known to have a single value.

Also update a few places that should use `hasSingleElement` instead of `.size() == 1`.



---

Patch is 26.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131460.diff


23 Files Affected:

- (modified) llvm/include/llvm/ADT/STLExtras.h (+8) 
- (modified) mlir/include/mlir/Dialect/CommonFolders.h (+2-4) 
- (modified) mlir/lib/Analysis/SliceAnalysis.cpp (+1-1) 
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+1-2) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+2-4) 
- (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+13-16) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp (+5-8) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+5-10) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-9) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+5-11) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+2-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+4-9) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+1-2) 
- (modified) mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp (+1-2) 
- (modified) mlir/test/lib/Analysis/TestCFGLoopInfo.cpp (+1-1) 


``````````diff
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 78b7e94c2b3a1..dc0443c9244be 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -325,6 +325,14 @@ template <typename ContainerTy> bool hasSingleElement(ContainerTy &&C) {
   return B != E && std::next(B) == E;
 }
 
+/// Asserts that the given container has a single element and returns that
+/// element.
+template <typename ContainerTy>
+decltype(auto) getSingleElement(ContainerTy &&C) {
+  assert(hasSingleElement(C) && "expected container with single element");
+  return *adl_begin(C);
+}
+
 /// Return a range covering \p RangeOrContainer with the first N elements
 /// excluded.
 template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N = 1) {
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 6f497a259262a..b5a12426aff80 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -196,8 +196,7 @@ template <class AttrElementT,
               function_ref<std::optional<ElementValueT>(ElementValueT)>>
 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
                                       CalculationT &&calculate) {
-  assert(operands.size() == 1 && "unary op takes one operands");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
@@ -268,8 +267,7 @@ template <
     class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
                           CalculationT &&calculate) {
-  assert(operands.size() == 1 && "Cast op takes one operand");
-  if (!operands[0])
+  if (!llvm::getSingleElement(operands))
     return {};
 
   static_assert(
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8803ba994b2c1..e01cb3a080b5c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
       // into us. For now, just bail.
       if (parentOp && backwardSlice->count(parentOp) == 0) {
         assert(parentOp->getNumRegions() == 1 &&
-               parentOp->getRegion(0).getBlocks().size() == 1);
+               llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
         getBackwardSliceImpl(parentOp, backwardSlice, options);
       }
     } else {
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1f2781aa82114..9c4dfa27b1447 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Type srcType = adaptor.getOperands().front().getType();
+    Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
     Type dstType = this->getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 1b0f023527891..df2da138d3b52 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() == 1);
-    Value cst = adaptor.getOperands().front();
+    Value cst = llvm::getSingleElement(adaptor.getOperands());
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
       return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
                                          "splat is not a composite construct");
     }
 
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
+    scalar = llvm::getSingleElement(cc.getConstituents());
 
     auto coopType = getTypeConverter()->convertType(op.getType());
     if (!coopType)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b0884d321bc8a..33391995885a4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     SmallVector<Value> dynDims, dynDevice;
     for (auto dim : adaptor.getDimsDynamic()) {
       // type conversion should be 1:1 for ints
-      assert(dim.size() == 1);
-      dynDims.emplace_back(dim[0]);
+      dynDims.emplace_back(llvm::getSingleElement(dim));
     }
     // same for device
     for (auto device : adaptor.getDeviceDynamic()) {
-      assert(device.size() == 1);
-      dynDevice.emplace_back(device[0]);
+      dynDevice.emplace_back(llvm::getSingleElement(device));
     }
 
     // To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
     typeConverter.addConversion([](Type type) { return type; });
 
     // convert mesh::ShardingType to a tuple of RankedTensorTypes
-    typeConverter.addConversion(
-        [](ShardingType type,
-           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-          auto i16 = IntegerType::get(type.getContext(), 16);
-          auto i64 = IntegerType::get(type.getContext(), 64);
-          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
-                                        ShapedType::kDynamic};
-          results.emplace_back(RankedTensorType::get(shp, i16));
-          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
-          results.emplace_back(RankedTensorType::get(shp, i64));
-          return success();
-        });
+    typeConverter.addConversion([](ShardingType type,
+                                   SmallVectorImpl<Type> &results)
+                                    -> std::optional<LogicalResult> {
+      auto i16 = IntegerType::get(type.getContext(), 16);
+      auto i64 = IntegerType::get(type.getContext(), 64);
+      std::array<int64_t, 2> shp = {ShapedType::kDynamic, ShapedType::kDynamic};
+      results.emplace_back(RankedTensorType::get(shp, i16));
+      results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+      results.emplace_back(RankedTensorType::get(shp, i64));
+      return success();
+    });
 
     // To 'extract' components, a UnrealizedConversionCastOp is expected
     // to define the input
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..9c5b9e82cd5e0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
   }
 
   applyOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
   }
 
   minMaxOp->erase();
-  assert(foldResults.size() == 1 && "expected 1 folded result");
-  return foldResults.front();
+  return llvm::getSingleElement(foldResults);
 }
 
 OpFoldResult
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bcba17bb21544..4b4eb9ce37b4c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
       SmallVector<Operation *, 2> sibLoadOpInsts;
       sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
       // Currently findSiblingNodeToFuse searches for siblings with one load.
-      assert(sibLoadOpInsts.size() == 1);
-      Operation *sibLoadOpInst = sibLoadOpInsts[0];
+      Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
 
       // Gather 'dstNode' load ops to 'memref'.
       SmallVector<Operation *, 2> dstLoadOpInsts;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 71c6acba32d2e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
                                                ArrayRef<uint64_t> sizes,
                                                AffineForOp target) {
   SmallVector<AffineForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 6fcfa05468eea..55a09622644ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
                                                        linalg::CopyOp> {
   OpOperand &getSourceOperand(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getInputs().size() == 1 && "expected single input");
-    return copyOp.getInputsMutable()[0];
+    return llvm::getSingleElement(copyOp.getInputsMutable());
   }
 
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return equivalenceFn(candidate, copyOp.getOutputs()[0]);
+    return equivalenceFn(candidate,
+                         llvm::getSingleElement(copyOp.getOutputs()));
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return copyOp.getOutputs()[0];
+    return llvm::getSingleElement(copyOp.getOutputs());
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
     auto copyOp = cast<CopyOp>(op);
-    assert(copyOp.getOutputs().size() == 1 && "expected single output");
-    return {copyOp.getOutputs()[0]};
+    return {llvm::getSingleElement(copyOp.getOutputs())};
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
 /// extending the lifetime of allocations.
 static bool lastNonTerminatorInRegion(Operation *op) {
   return op->getNextNode() == op->getBlock()->getTerminator() &&
-         op->getParentRegion()->getBlocks().size() == 1;
+         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
 }
 
 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 71b88d1be1b05..de834fed90e42 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {
 
   static Value materializeConversion(OpBuilder &builder, Type type,
                                      ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
-    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+    return builder.create<quant::StorageCastOp>(loc, type,
+                                                llvm::getSingleElement(inputs));
   }
 
 public:
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e9d7dc1b847c6..ee46f9c97268b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 static bool doesNotAliasExternalValue(Value value, Region *region,
                                       ValueRange exceptions,
                                       const OneShotAnalysisState &state) {
-  assert(region->getBlocks().size() == 1 &&
+  assert(llvm::hasSingleElement(region->getBlocks()) &&
          "expected region with single block");
   bool result = true;
   state.applyOnAliases(value, [&](Value alias) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c0589044c26ec..40d2e254fb7dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 // CRTP
 // A base class that takes care of 1:N type conversion, which maps the converted
 // op results (computed by the derived class) and materializes 1:N conversion.
@@ -119,9 +113,9 @@ class ConvertForOpTypes
     // We can not do clone as the number of result types after conversion
     // might be different.
     ForOp newOp = rewriter.create<ForOp>(
-        op.getLoc(), getSingleValue(adaptor.getLowerBound()),
-        getSingleValue(adaptor.getUpperBound()),
-        getSingleValue(adaptor.getStep()),
+        op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
+        llvm::getSingleElement(adaptor.getUpperBound()),
+        llvm::getSingleElement(adaptor.getStep()),
         flattenValues(adaptor.getInitArgs()));
 
     // Reserve whatever attributes in the original op.
@@ -149,7 +143,8 @@ class ConvertIfOpTypes
                                       TypeRange dstTypes) const {
 
     IfOp newOp = rewriter.create<IfOp>(
-        op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
+        op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
+        true);
     newOp->setAttrs(op->getAttrs());
 
     // We do not need the empty blocks created by rewriter.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 19335255fd492..e9471c1dbd0b7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                  scf::ForOp target) {
   SmallVector<scf::ForOp, 8> res;
-  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
-    assert(loops.size() == 1);
-    res.push_back(loops[0]);
-  }
+  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
+    res.push_back(llvm::getSingleElement(loops));
   return res;
 }
 
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 66a2e45001781..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -38,7 +38,7 @@ struct AssumingOpInterface
     size_t resultNum = std::distance(op->getOpResults().begin(),
                                      llvm::find(op->getOpResults(), value));
     // TODO: Support multiple blocks.
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "expected exactly 1 block");
     auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
@@ -49,7 +49,7 @@ struct AssumingOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
-    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+    assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "only 1 block supported");
     auto yieldOp = cast<shape::AssumingYieldOp>(
         assumingOp.getDoRegion().front().getTerminator());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 9e9fea76416b9..948ba60ac0bbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -12,12 +12,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                              SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
@@ -200,7 +194,7 @@ class ExtractIterSpaceConverter
 
     // Construct the iteration space.
     SparseIterationSpace space(loc, rewriter,
-                               getSingleValue(adaptor.getTensor()), 0,
+                               llvm::getSingleElement(adaptor.getTensor()), 0,
                                op.getLvlRange(), adaptor.getParentIter());
 
     SmallVector<Value> result = space.toValues();
@@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value pos = adaptor.getIterator().back();
-    Value valBuf =
-        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
+    Value valBuf = rewriter.create<ToValuesOp>(
+        loc, llvm::getSingleElement(adaptor.getTensor()));
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20d46f7ca00c5..6a66ad24a87b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
   return result;
 }
 
-/// Assert that the given value range contains a single value and return it.
-static Value getSingleValue(ValueRange values) {
-  assert(values.size() == 1 && "expected single value");
-  return values.front();
-}
-
 /// Generates a load with proper `index` typing.
 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
   idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/131460


More information about the Mlir-commits mailing list