[Mlir-commits] [mlir] [mlir][Value] Add getNumUses, hasNUses, and hasNUsesOrMore to Value (PR #142084)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 29 21:11:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Michael Maitland (michaelmaitland)
<details>
<summary>Changes</summary>
We already have hasOneUse. Like llvm::Value, this patch provides helper methods (getNumUses, hasNUses, hasNUsesOrMore) to query the number of uses of a Value. It also adds unit tests for Value, which were previously missing.
---
Full diff: https://github.com/llvm/llvm-project/pull/142084.diff
8 Files Affected:
- (modified) mlir/docs/Tutorials/UnderstandingTheIRStructure.md (+3-7)
- (modified) mlir/include/mlir/IR/Value.h (+14)
- (modified) mlir/lib/Bytecode/Reader/BytecodeReader.cpp (+1-2)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+1-1)
- (modified) mlir/lib/IR/Value.cpp (+12)
- (modified) mlir/test/lib/IR/TestPrintDefUse.cpp (+3-7)
- (modified) mlir/unittests/IR/CMakeLists.txt (+1)
- (added) mlir/unittests/IR/ValueTest.cpp (+90)
``````````diff
diff --git a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md
index 595d6949a03f3..30b50cb09490d 100644
--- a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md
+++ b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md
@@ -257,14 +257,10 @@ results and print informations about them:
llvm::outs() << " has no uses\n";
continue;
}
- if (result.hasOneUse()) {
+ if (result.hasOneUse())
llvm::outs() << " has a single use: ";
- } else {
- llvm::outs() << " has "
- << std::distance(result.getUses().begin(),
- result.getUses().end())
- << " uses:\n";
- }
+ else
+ llvm::outs() << " has " << result.getNumUses() << " uses:\n";
for (Operation *userOp : result.getUsers()) {
llvm::outs() << " - " << userOp->getName() << "\n";
}
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index d54e3c0ad26dd..4d6d89fa69a07 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -187,9 +187,23 @@ class Value {
/// Returns a range of all uses, which is useful for iterating over all uses.
use_range getUses() const { return {use_begin(), use_end()}; }
+ /// This method computes the number of uses of this Value.
+ ///
+ /// This is a linear time operation. Use hasOneUse, hasNUses, or
+ /// hasNUsesOrMore to check for specific values.
+ unsigned getNumUses() const;
+
/// Returns true if this value has exactly one use.
bool hasOneUse() const { return impl->hasOneUse(); }
+ /// Return true if this Value has exactly n uses.
+ bool hasNUses(unsigned n) const;
+
+ /// Return true if this value has n uses or more.
+ ///
+ /// This is logically equivalent to getNumUses() >= N.
+ bool hasNUsesOrMore(unsigned n) const;
+
/// Returns true if this value has no uses.
bool use_empty() const { return impl->use_empty(); }
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1052946d4550b..44458d010c6c8 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1993,8 +1993,7 @@ LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
UseListOrderStorage customOrder =
valueToUseListMap.at(value.getAsOpaquePointer());
SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
- uint64_t numUses =
- std::distance(value.getUses().begin(), value.getUses().end());
+ uint64_t numUses = value.getNumUses();
// If the encoding was a pair of indices `(src, dst)` for every permutation,
// reconstruct the shuffle vector for every use. Initialize the shuffle vector
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 748379ea671be..5a0b8a058dd65 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1787,7 +1787,7 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
- if (iv.getUses().begin() == iv.getUses().end())
+ if (iv.hasNUses(0))
continue;
auto numIterations = constantTripCount(lb, ub, step);
if (!numIterations.has_value() || numIterations.value() != 1) {
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 178765353cc10..7b3a9462a0917 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -51,6 +51,18 @@ Block *Value::getParentBlock() {
return llvm::cast<BlockArgument>(*this).getOwner();
}
+unsigned Value::getNumUses() const {
+ return (unsigned)std::distance(use_begin(), use_end());
+}
+
+bool Value::hasNUses(unsigned n) const {
+ return hasNItems(use_begin(), use_end(), n);
+}
+
+bool Value::hasNUsesOrMore(unsigned n) const {
+ return hasNItemsOrMore(use_begin(), use_end(), n);
+}
+
//===----------------------------------------------------------------------===//
// Value::UseLists
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestPrintDefUse.cpp b/mlir/test/lib/IR/TestPrintDefUse.cpp
index 5d489a342f57d..b983366fc16dd 100644
--- a/mlir/test/lib/IR/TestPrintDefUse.cpp
+++ b/mlir/test/lib/IR/TestPrintDefUse.cpp
@@ -49,14 +49,10 @@ struct TestPrintDefUsePass
llvm::outs() << " has no uses\n";
continue;
}
- if (result.hasOneUse()) {
+ if (result.hasOneUse())
llvm::outs() << " has a single use: ";
- } else {
- llvm::outs() << " has "
- << std::distance(result.getUses().begin(),
- result.getUses().end())
- << " uses:\n";
- }
+ else
+ llvm::outs() << " has " << result.getNumUses() << " uses:\n";
for (Operation *userOp : result.getUsers()) {
llvm::outs() << " - " << userOp->getName() << "\n";
}
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 821ff7d14dabd..9ab6029c3480d 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_unittest(MLIRIRTests
TypeTest.cpp
TypeAttrNamesTest.cpp
OpPropertiesTest.cpp
+ ValueTest.cpp
DEPENDS
MLIRTestInterfaceIncGen
diff --git a/mlir/unittests/IR/ValueTest.cpp b/mlir/unittests/IR/ValueTest.cpp
new file mode 100644
index 0000000000000..e31d9f32bc1d1
--- /dev/null
+++ b/mlir/unittests/IR/ValueTest.cpp
@@ -0,0 +1,90 @@
+//===- mlir/unittest/IR/ValueTest.cpp - Value unit tests ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Value.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static Operation *createOp(MLIRContext *context,
+ ArrayRef<Value> operands = std::nullopt,
+ ArrayRef<Type> resultTypes = std::nullopt,
+ unsigned int numRegions = 0) {
+ context->allowUnregisteredDialects();
+ return Operation::create(
+ UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes,
+ operands, std::nullopt, nullptr, std::nullopt, numRegions);
+}
+
+namespace {
+
+TEST(ValueTest, getNumUses) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ Operation *op0 =
+ createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
+
+ Value v0 = op0->getResult(0);
+ EXPECT_EQ(v0.getNumUses(), (unsigned)0);
+
+ createOp(&context, {v0}, builder.getIntegerType(16));
+ EXPECT_EQ(v0.getNumUses(), (unsigned)1);
+
+ createOp(&context, {v0, v0}, builder.getIntegerType(16));
+ EXPECT_EQ(v0.getNumUses(), (unsigned)3);
+}
+
+TEST(ValueTest, hasNUses) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ Operation *op =
+ createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
+ Value v0 = op->getResult(0);
+ EXPECT_TRUE(v0.hasNUses(0));
+ EXPECT_FALSE(v0.hasNUses(1));
+
+ createOp(&context, {v0}, builder.getIntegerType(16));
+ EXPECT_FALSE(v0.hasNUses(0));
+ EXPECT_TRUE(v0.hasNUses(1));
+
+ createOp(&context, {v0, v0}, builder.getIntegerType(16));
+ EXPECT_FALSE(v0.hasNUses(0));
+ EXPECT_FALSE(v0.hasNUses(1));
+ EXPECT_TRUE(v0.hasNUses(3));
+}
+
+TEST(ValueTest, hasNUsesOrMore) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ Operation *op =
+ createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
+ Value v0 = op->getResult(0);
+ EXPECT_TRUE(v0.hasNUsesOrMore(0));
+ EXPECT_FALSE(v0.hasNUsesOrMore(1));
+
+ createOp(&context, {v0}, builder.getIntegerType(16));
+ EXPECT_TRUE(v0.hasNUsesOrMore(0));
+ EXPECT_TRUE(v0.hasNUsesOrMore(1));
+ EXPECT_FALSE(v0.hasNUsesOrMore(2));
+
+ createOp(&context, {v0, v0}, builder.getIntegerType(16));
+ EXPECT_TRUE(v0.hasNUsesOrMore(0));
+ EXPECT_TRUE(v0.hasNUsesOrMore(1));
+ EXPECT_TRUE(v0.hasNUsesOrMore(3));
+ EXPECT_FALSE(v0.hasNUsesOrMore(4));
+}
+
+} // end anonymous namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/142084
More information about the Mlir-commits mailing list