[Mlir-commits] [mlir] [mlir] Make single value `ValueRange`s memory safer (PR #121996)

Markus Böck llvmlistbot at llvm.org
Tue Jan 7 13:30:01 PST 2025


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/121996

A very common mistake users (and yours truly) make when using `ValueRange`s is assigning a temporary `Value` to it. Example:
```cpp
ValueRange values = op.getOperand();
apiThatUsesValueRange(values);
```

The issue is caused by the implicit `const Value&` constructor: per C++ rules, a const reference can bind to a temporary, and the constructor then takes the address of that temporary. After the statement, the temporary goes out of scope, the `ValueRange` is left pointing at it, and a `stack-use-after-free` error occurs when the range is used.

This PR fixes that issue by making `ValueRange` capable of owning a single `Value` instance for that case specifically. While technically a departure from the other owner types that are non-owning, I'd argue that this behavior is more intuitive for the majority of users that usually don't need to care about the lifetime of `Value` instances.

`TypeRange` has similarly been adapted to accept a single `Type` instance to implement `getTypes`.

>From 74eac95870e5516bb0bdf6013cc35de4f7dc1800 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Tue, 7 Jan 2025 22:27:57 +0100
Subject: [PATCH] [mlir] Make single value `ValueRange`s memory safer

A very common mistake users (and yours truly) make when using `ValueRange`s is assigning a temporary `Value` to it.
Example:
```cpp
ValueRange values = op.getOperand();
apiThatUsesValueRange(values);
```

The issue is caused by the implicit `const Value&` constructor: per C++ rules, a const reference can bind to a temporary, and the constructor then takes the address of that temporary.
After the statement, the temporary goes out of scope, the `ValueRange` is left pointing at it, and a `stack-use-after-free` error occurs when the range is used.

This PR fixes that issue by making `ValueRange` capable of owning a single `Value` instance for that case specifically. While technically a departure from the other owner types that are non-owning, I'd argue that this behavior is more intuitive for the majority of users that usually don't need to care about the lifetime of `Value` instances.

`TypeRange` has similarly been adapted to accept a single `Type` instance to implement `getTypes`.
---
 mlir/include/mlir/IR/TypeRange.h           | 16 ++++++++++------
 mlir/include/mlir/IR/ValueRange.h          | 11 ++++++-----
 mlir/lib/IR/OperationSupport.cpp           | 13 +++++++++++++
 mlir/lib/IR/TypeRange.cpp                  | 15 +++++++++++++++
 mlir/unittests/IR/OperationSupportTest.cpp | 17 +++++++++++++++++
 5 files changed, 61 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index 99fabab334f922..3a255583e28583 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -31,9 +31,9 @@ namespace mlir {
 /// parameter.
 class TypeRange : public llvm::detail::indexed_accessor_range_base<
                       TypeRange,
-                      llvm::PointerUnion<const Value *, const Type *,
-                                         OpOperand *, detail::OpResultImpl *>,
-                      Type, Type, Type> {
+          llvm::PointerUnion<const Value *, const Type *, OpOperand *,
+                             detail::OpResultImpl *, Type>,
+          Type, Type, Type> {
 public:
   using RangeBaseT::RangeBaseT;
   TypeRange(ArrayRef<Type> types = std::nullopt);
@@ -44,8 +44,11 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
   TypeRange(ValueTypeRange<ValueRangeT> values)
       : TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(),
                                          values.end().getCurrent()))) {}
-  template <typename Arg, typename = std::enable_if_t<std::is_constructible<
-                              ArrayRef<Type>, Arg>::value>>
+
+  TypeRange(Type type) : TypeRange(type, /*count=*/1) {}
+  template <typename Arg, typename = std::enable_if_t<
+                              std::is_constructible_v<ArrayRef<Type>, Arg> &&
+                              !std::is_constructible_v<Type, Arg>>>
   TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
   TypeRange(std::initializer_list<Type> types)
       : TypeRange(ArrayRef<Type>(types)) {}
@@ -56,8 +59,9 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
   /// * A pointer to the first element of an array of types.
   /// * A pointer to the first element of an array of operands.
   /// * A pointer to the first element of an array of results.
+  /// * A single 'Type' instance.
   using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
-                                    detail::OpResultImpl *>;
+                                    detail::OpResultImpl *, Type>;
 
   /// See `llvm::detail::indexed_accessor_range_base` for details.
   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 4b421c08d8418e..f878abd63de35f 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -377,13 +377,14 @@ class ResultRange::UseIterator final
 class ValueRange final
     : public llvm::detail::indexed_accessor_range_base<
           ValueRange,
-          PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
-          Value, Value, Value> {
+                             PointerUnion<const Value *, OpOperand *,
+                                          detail::OpResultImpl *, Value>,
+                             Value, Value, Value> {
 public:
   /// The type representing the owner of a ValueRange. This is either a list of
-  /// values, operands, or results.
+  /// values, operands, or results or a single value.
   using OwnerT =
-      PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
+      PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *, Value>;
 
   using RangeBaseT::RangeBaseT;
 
@@ -392,7 +393,7 @@ class ValueRange final
                 std::is_constructible<ArrayRef<Value>, Arg>::value &&
                 !std::is_convertible<Arg, Value>::value>>
   ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
-  ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
+  ValueRange(Value value) : ValueRange(value, /*count=*/1) {}
   ValueRange(const std::initializer_list<Value> &values)
       : ValueRange(ArrayRef<Value>(values)) {}
   ValueRange(iterator_range<OperandRange::iterator> values)
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 957195202d78d2..803fcd8d18fbd5 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -653,6 +653,15 @@ ValueRange::ValueRange(ResultRange values)
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
                                            ptrdiff_t index) {
+  if (llvm::isa_and_nonnull<Value>(owner)) {
+    // Prevent out-of-bounds indexing for single values.
+    // Note that we do allow an index of 1 as is required by 'slice'ing that
+    // returns an empty range. This also matches the usual rules of C++ of being
+    // allowed to index past the last element of an array.
+    assert(index <= 1 && "out-of-bound offset into single-value 'ValueRange'");
+    // Return nullptr to quickly cause segmentation faults on misuse.
+    return index == 0 ? owner : nullptr;
+  }
   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
     return {value + index};
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
@@ -661,6 +670,10 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
 }
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
+  if (auto value = llvm::dyn_cast_if_present<Value>(owner)) {
+    assert(index == 0 && "cannot offset into single-value 'ValueRange'");
+    return value;
+  }
   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
     return value[index];
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index f8878303727d4f..7e5f99c884512e 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -31,12 +31,23 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
     this->base = result;
   else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     this->base = operand;
+  else if (auto value = llvm::dyn_cast_if_present<Value>(owner))
+    this->base = value.getType();
   else
     this->base = cast<const Value *>(owner);
 }
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
+  if (llvm::isa_and_nonnull<Type>(object)) {
+    // Prevent out-of-bounds indexing for single values.
+    // Note that we do allow an index of 1 as is required by 'slice'ing that
+    // returns an empty range. This also matches the usual rules of C++ of being
+    // allowed to index past the last element of an array.
+    assert(index <= 1 && "out-of-bound offset into single-value 'ValueRange'");
+    // Return nullptr to quickly cause segmentation faults on misuse.
+    return index == 0 ? object : nullptr;
+  }
   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
     return {value + index};
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
@@ -48,6 +59,10 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
+  if (auto type = llvm::dyn_cast_if_present<Type>(object)) {
+    assert(index == 0 && "cannot offset into single-value 'TypeRange'");
+    return type;
+  }
   if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
     return (value + index)->getType();
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f94dc784458077..2a1b8d2ef7f55b 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -313,4 +313,21 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
   op2->destroy();
 }
 
+TEST(ValueRangeTest, ValueConstructable) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Operation *useOp =
+      createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
+  // Valid construction despite a temporary 'OpResult'.
+  ValueRange operands = useOp->getResult(0);
+
+  useOp->setOperands(operands);
+  EXPECT_EQ(useOp->getNumOperands(), 1u);
+  EXPECT_EQ(useOp->getOperand(0), useOp->getResult(0));
+
+  useOp->dropAllUses();
+  useOp->destroy();
+}
+
 } // namespace



More information about the Mlir-commits mailing list