[Mlir-commits] [mlir] [mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Builder (PR #68969)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Oct 18 02:04:38 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/68969

>From 081be10b57e31a785b2b8db84cb3447eababf36d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 13 Oct 2023 09:33:25 +0000
Subject: [PATCH 1/6] [mlir] Fix use-after-free bugs in
 {RankedTensorType|VectorType}::Builder

Previously, these would set their ArrayRef members to reference their
storage SmallVectors after a copy-on-write operation. This leads to a
use-after-free if the builder is copied and the original destroyed (as
the new builder would still reference the old SmallVector).

The VectorType::Builder also set the ArrayRef<bool> scalableDims member
to a temporary SmallVector when the provided scalableDims are empty.
This again led to a use-after-free, and is unnecessary as
VectorType::get already handles being passed an empty scalableDims
array.

These bugs were in part caught by UBSAN, see:
https://lab.llvm.org/buildbot/#/builders/5/builds/37355
---
 mlir/include/mlir/IR/BuiltinTypes.h | 43 ++++++++++++++++-------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 9df5548cd5d939c..13fbae90b68c6cb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -277,7 +277,7 @@ class RankedTensorType::Builder {
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
     storage.erase(storage.begin() + pos);
-    shape = {storage.data(), storage.size()};
+    shape = {};
     return *this;
   }
 
@@ -287,12 +287,17 @@ class RankedTensorType::Builder {
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
     storage.insert(storage.begin() + pos, val);
-    shape = {storage.data(), storage.size()};
+    shape = {};
     return *this;
   }
 
+  /// Returns the current shape.
+  ArrayRef<int64_t> getShape() {
+    return shape.empty() ? ArrayRef(storage) : shape;
+  }
+
   operator RankedTensorType() {
-    return RankedTensorType::get(shape, elementType, encoding);
+    return RankedTensorType::get(getShape(), elementType, encoding);
   }
 
 private:
@@ -319,20 +324,11 @@ class VectorType::Builder {
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
           unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
-      : shape(shape), elementType(elementType) {
-    if (scalableDims.empty())
-      scalableDims = SmallVector<bool>(shape.size(), false);
-    else
-      this->scalableDims = scalableDims;
-  }
+      : shape(shape), elementType(elementType), scalableDims(scalableDims) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
-    if (newIsScalableDim.empty())
-      scalableDims = SmallVector<bool>(shape.size(), false);
-    else
-      scalableDims = newIsScalableDim;
-
+    scalableDims = newIsScalableDim;
     shape = newShape;
     return *this;
   }
@@ -351,9 +347,8 @@ class VectorType::Builder {
       storageScalableDims.append(scalableDims.begin(), scalableDims.end());
     storage.erase(storage.begin() + pos);
     storageScalableDims.erase(storageScalableDims.begin() + pos);
-    shape = {storage.data(), storage.size()};
-    scalableDims =
-        ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
+    shape = {};
+    scalableDims = {};
     return *this;
   }
 
@@ -363,12 +358,22 @@ class VectorType::Builder {
       storage.append(shape.begin(), shape.end());
     assert(pos < storage.size() && "overflow");
     storage[pos] = val;
-    shape = {storage.data(), storage.size()};
+    shape = {};
     return *this;
   }
 
+  /// Returns the current shape.
+  ArrayRef<int64_t> getShape() {
+    return shape.empty() ? ArrayRef(storage) : shape;
+  }
+
+  /// Returns the current scalable dims.
+  ArrayRef<bool> getScalableDims() {
+    return scalableDims.empty() ? ArrayRef(storageScalableDims) : scalableDims;
+  }
+
   operator VectorType() {
-    return VectorType::get(shape, elementType, scalableDims);
+    return VectorType::get(getShape(), elementType, getScalableDims());
   }
 
 private:

>From 1ef6ebbc0afc57e4e5b6b424e3ffe428634d2930 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 16 Oct 2023 10:41:31 +0000
Subject: [PATCH 2/6] Fixup asserts and add some simple tests

---
 mlir/include/mlir/IR/BuiltinTypes.h  |  6 +--
 mlir/unittests/IR/ShapedTypeTest.cpp | 70 ++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 13fbae90b68c6cb..41ea6ff3185315b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -273,9 +273,9 @@ class RankedTensorType::Builder {
 
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
-    assert(pos < shape.size() && "overflow");
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
     storage.erase(storage.begin() + pos);
     shape = {};
     return *this;
@@ -283,9 +283,9 @@ class RankedTensorType::Builder {
 
   /// Insert a val into shape @pos.
   Builder &insertDim(int64_t val, unsigned pos) {
-    assert(pos <= shape.size() && "overflow");
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
+    assert(pos <= storage.size() && "overflow");
     storage.insert(storage.begin() + pos, val);
     shape = {};
     return *this;
@@ -340,9 +340,9 @@ class VectorType::Builder {
 
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
-    assert(pos < shape.size() && "overflow");
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
     if (storageScalableDims.empty())
       storageScalableDims.append(scalableDims.begin(), scalableDims.end());
     storage.erase(storage.begin() + pos);
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 82674fd3768b6cb..18247e11e52390c 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -131,4 +131,74 @@ TEST(ShapedTypeTest, CloneVector) {
             VectorType::get(vectorNewShape, vectorNewType));
 }
 
+TEST(ShapedTypeTest, VectorTypeBuilder) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  VectorType vectorType =
+      VectorType::get({2, 4, 8, 9, 1}, f32, {true, false, true, false, false});
+
+  {
+    // Drop some dims.
+    VectorType dropFrontTwoDims =
+        VectorType::Builder(vectorType).dropDim(0).dropDim(0);
+    ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
+    ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
+    ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
+              dropFrontTwoDims.getScalableDims());
+  }
+
+  {
+    // Set some dims.
+    VectorType setTwoDims =
+        VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
+    ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
+    ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
+    ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
+  }
+
+  {
+    // Test for bug from:
+    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
+    // Constructs a temporary builder, modifies it, copies it to `builder`.
+    VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
+    VectorType newVectorType = VectorType(builder);
+    ASSERT_EQ(newVectorType.getDimSize(0), 16);
+  }
+}
+
+TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  RankedTensorType tensorType = RankedTensorType::get({2, 4, 8, 16, 32}, f32);
+
+  {
+    // Drop some dims.
+    RankedTensorType dropFrontTwoDims =
+        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(0);
+    ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
+    ASSERT_EQ(tensorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
+  }
+
+  {
+    // Insert some dims.
+    RankedTensorType insertTwoDims =
+        RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
+    ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
+    ASSERT_EQ(insertTwoDims.getShape(),
+              ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
+  }
+
+  {
+    // Test for bug from:
+    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
+    // Constructs a temporary builder, modifies it, copies it to `builder`.
+    RankedTensorType::Builder builder =
+        RankedTensorType::Builder(tensorType).dropDim(0);
+    RankedTensorType newTensorType = RankedTensorType(builder);
+    ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
+  }
+}
+
 } // namespace

>From 6a462fdb350ddf456ac1d07d9056770be097d917 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 16 Oct 2023 10:56:25 +0000
Subject: [PATCH 3/6] Add a few more tests

---
 mlir/include/mlir/IR/BuiltinTypes.h  |  2 +-
 mlir/unittests/IR/ShapedTypeTest.cpp | 22 ++++++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 41ea6ff3185315b..8efb62fd1ec6e7c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -323,7 +323,7 @@ class VectorType::Builder {
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
+          ArrayRef<bool> scalableDims = {})
       : shape(shape), elementType(elementType), scalableDims(scalableDims) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 18247e11e52390c..7a249122db38544 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -161,10 +161,30 @@ TEST(ShapedTypeTest, VectorTypeBuilder) {
     // Test for bug from:
     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
     // Constructs a temporary builder, modifies it, copies it to `builder`.
+    // This used to lead to a use-after-free. Running under sanitizers will
+    // catch any issues.
     VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
     VectorType newVectorType = VectorType(builder);
     ASSERT_EQ(newVectorType.getDimSize(0), 16);
   }
+
+  {
+    // Make builder from scratch (without scalable dims) -- this used to lead to
+    // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
+    // Running under sanitizers will catch any issues.
+    SmallVector<int64_t> shape{1, 2, 3, 4};
+    VectorType::Builder builder(shape, f32);
+    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
+  }
+
+  {
+    // Set vector shape (without scalable dims) -- this used to lead to
+    // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
+    // Running under sanitizers will catch any issues.
+    VectorType::Builder builder(vectorType);
+    builder.setShape({2, 2});
+    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef<int64_t>({2, 2}));
+  }
 }
 
 TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
@@ -194,6 +214,8 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
     // Test for bug from:
     // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
     // Constructs a temporary builder, modifies it, copies it to `builder`.
+    // This used to lead to a use-after-free. Running under sanitizers will
+    // catch any issues.
     RankedTensorType::Builder builder =
         RankedTensorType::Builder(tensorType).dropDim(0);
     RankedTensorType newTensorType = RankedTensorType(builder);

>From 1fb8e9aae81b107299acdc8d379d5ccfe91f4d26 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 16 Oct 2023 11:42:40 +0000
Subject: [PATCH 4/6] Fix some unsafe std::initializer_list usages in
 tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Easy to mess up 🤦
---
 mlir/unittests/IR/ShapedTypeTest.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 7a249122db38544..618851a4e72667b 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -135,8 +135,9 @@ TEST(ShapedTypeTest, VectorTypeBuilder) {
   MLIRContext context;
   Type f32 = FloatType::getF32(&context);
 
-  VectorType vectorType =
-      VectorType::get({2, 4, 8, 9, 1}, f32, {true, false, true, false, false});
+  SmallVector<int64_t> shape{2, 4, 8, 9, 1};
+  SmallVector<bool> scalableDims{true, false, true, false, false};
+  VectorType vectorType = VectorType::get(shape, f32, scalableDims);
 
   {
     // Drop some dims.
@@ -182,8 +183,9 @@ TEST(ShapedTypeTest, VectorTypeBuilder) {
     // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
     // Running under sanitizers will catch any issues.
     VectorType::Builder builder(vectorType);
-    builder.setShape({2, 2});
-    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef<int64_t>({2, 2}));
+    SmallVector<int64_t> newShape{2, 2};
+    builder.setShape(newShape);
+    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
   }
 }
 
@@ -191,7 +193,8 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   MLIRContext context;
   Type f32 = FloatType::getF32(&context);
 
-  RankedTensorType tensorType = RankedTensorType::get({2, 4, 8, 16, 32}, f32);
+  SmallVector<int64_t> shape{2, 4, 8, 16, 32};
+  RankedTensorType tensorType = RankedTensorType::get(shape, f32);
 
   {
     // Drop some dims.

>From e7ca4a9eeca61814f2738d807457f1819448af39 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 17 Oct 2023 09:27:50 +0000
Subject: [PATCH 5/6] Move builder COW pattern to safe CopyOnWriteArrayRef<T>
 class.

This is a safer helper class that avoids the issues with the previous
manual approaches. It also cleans up the builder implementations.
---
 mlir/include/mlir/IR/BuiltinTypes.h | 125 ++++++++++++++++------------
 1 file changed, 72 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 8efb62fd1ec6e7c..69dd397eda791d8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -191,6 +191,60 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// CopyOnWriteArrayRef<T>
+//===----------------------------------------------------------------------===//
+
+// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
+// modification. This is for use in the mlir::<Type>::Builders.
+template <typename T>
+class CopyOnWriteArrayRef {
+public:
+  CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};
+
+  CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
+    nonOwning = array;
+    owningStorage = {};
+    return *this;
+  }
+
+  void insert(size_t index, T value) {
+    SmallVector<T> &vector = ensureCopy();
+    vector.insert(vector.begin() + index, value);
+  }
+
+  void erase(size_t index) {
+    SmallVector<T> &vector = ensureCopy();
+    vector.erase(vector.begin() + index);
+  }
+
+  void set(size_t index, T value) { ensureCopy()[index] = value; }
+
+  size_t size() const { return ArrayRef<T>(*this).size(); }
+
+  bool empty() const { return ArrayRef<T>(*this).empty(); }
+
+  operator ArrayRef<T>() const {
+    return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
+  }
+
+private:
+  SmallVector<T> &ensureCopy() {
+    // Empty non-owning storage signals the array has been copied to the owning
+    // storage (or both are empty). Note: `nonOwning` should never reference
+    // `owningStorage`. This can lead to dangling references if the
+    // CopyOnWriteArrayRef<T> is copied.
+    if (!nonOwning.empty()) {
+      owningStorage = SmallVector<T>(nonOwning);
+      nonOwning = {};
+    }
+    return owningStorage;
+  }
+
+  ArrayRef<T> nonOwning;
+  SmallVector<T> owningStorage;
+};
+
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//
@@ -273,37 +327,24 @@ class RankedTensorType::Builder {
 
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
-    if (storage.empty())
-      storage.append(shape.begin(), shape.end());
-    assert(pos < storage.size() && "overflow");
-    storage.erase(storage.begin() + pos);
-    shape = {};
+    assert(pos < shape.size() && "overflow");
+    shape.erase(pos);
     return *this;
   }
 
   /// Insert a val into shape @pos.
   Builder &insertDim(int64_t val, unsigned pos) {
-    if (storage.empty())
-      storage.append(shape.begin(), shape.end());
-    assert(pos <= storage.size() && "overflow");
-    storage.insert(storage.begin() + pos, val);
-    shape = {};
+    assert(pos <= shape.size() && "overflow");
+    shape.insert(pos, val);
     return *this;
   }
 
-  /// Returns the current shape.
-  ArrayRef<int64_t> getShape() {
-    return shape.empty() ? ArrayRef(storage) : shape;
-  }
-
   operator RankedTensorType() {
-    return RankedTensorType::get(getShape(), elementType, encoding);
+    return RankedTensorType::get(shape, elementType, encoding);
   }
 
 private:
-  ArrayRef<int64_t> shape;
-  // Owning shape data for copy-on-write operations.
-  SmallVector<int64_t> storage;
+  CopyOnWriteArrayRef<int64_t> shape;
   Type elementType;
   Attribute encoding;
 };
@@ -318,18 +359,18 @@ class VectorType::Builder {
 public:
   /// Build from another VectorType.
   explicit Builder(VectorType other)
-      : shape(other.getShape()), elementType(other.getElementType()),
+      : elementType(other.getElementType()), shape(other.getShape()),
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
           ArrayRef<bool> scalableDims = {})
-      : shape(shape), elementType(elementType), scalableDims(scalableDims) {}
+      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
-    scalableDims = newIsScalableDim;
     shape = newShape;
+    scalableDims = newIsScalableDim;
     return *this;
   }
 
@@ -340,50 +381,28 @@ class VectorType::Builder {
 
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
-    if (storage.empty())
-      storage.append(shape.begin(), shape.end());
-    assert(pos < storage.size() && "overflow");
-    if (storageScalableDims.empty())
-      storageScalableDims.append(scalableDims.begin(), scalableDims.end());
-    storage.erase(storage.begin() + pos);
-    storageScalableDims.erase(storageScalableDims.begin() + pos);
-    shape = {};
-    scalableDims = {};
+    assert(pos < shape.size() && "overflow");
+    shape.erase(pos);
+    if (!scalableDims.empty())
+      scalableDims.erase(pos);
     return *this;
   }
 
   /// Set a dim in shape @pos to val.
   Builder &setDim(unsigned pos, int64_t val) {
-    if (storage.empty())
-      storage.append(shape.begin(), shape.end());
-    assert(pos < storage.size() && "overflow");
-    storage[pos] = val;
-    shape = {};
+    assert(pos < shape.size() && "overflow");
+    shape.set(pos, val);
     return *this;
   }
 
-  /// Returns the current shape.
-  ArrayRef<int64_t> getShape() {
-    return shape.empty() ? ArrayRef(storage) : shape;
-  }
-
-  /// Returns the current scalable dims.
-  ArrayRef<bool> getScalableDims() {
-    return scalableDims.empty() ? ArrayRef(storageScalableDims) : scalableDims;
-  }
-
   operator VectorType() {
-    return VectorType::get(getShape(), elementType, getScalableDims());
+    return VectorType::get(shape, elementType, scalableDims);
   }
 
 private:
-  ArrayRef<int64_t> shape;
-  // Owning shape data for copy-on-write operations.
-  SmallVector<int64_t> storage;
   Type elementType;
-  ArrayRef<bool> scalableDims;
-  // Owning scalableDims data for copy-on-write operations.
-  SmallVector<bool> storageScalableDims;
+  CopyOnWriteArrayRef<int64_t> shape;
+  CopyOnWriteArrayRef<bool> scalableDims;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of

>From 1e6248d3307cf0ef190992668e4075b8fa7a68d0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 18 Oct 2023 09:02:35 +0000
Subject: [PATCH 6/6] Move CopyOnWriteArrayRef to mlir/Support/ADTExtras.h

+ Some minor tweaks
---
 mlir/include/mlir/IR/BuiltinTypes.h   | 55 +-----------------
 mlir/include/mlir/Support/ADTExtras.h | 82 +++++++++++++++++++++++++++
 mlir/unittests/IR/ShapedTypeTest.cpp  |  4 +-
 3 files changed, 85 insertions(+), 56 deletions(-)
 create mode 100644 mlir/include/mlir/Support/ADTExtras.h

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 69dd397eda791d8..92ce053ad5c829b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/Support/ADTExtras.h"
 
 namespace llvm {
 class BitVector;
@@ -191,60 +192,6 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
 
 namespace mlir {
 
-//===----------------------------------------------------------------------===//
-// CopyOnWriteArrayRef<T>
-//===----------------------------------------------------------------------===//
-
-// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
-// modification. This is for use in the mlir::<Type>::Builders.
-template <typename T>
-class CopyOnWriteArrayRef {
-public:
-  CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};
-
-  CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
-    nonOwning = array;
-    owningStorage = {};
-    return *this;
-  }
-
-  void insert(size_t index, T value) {
-    SmallVector<T> &vector = ensureCopy();
-    vector.insert(vector.begin() + index, value);
-  }
-
-  void erase(size_t index) {
-    SmallVector<T> &vector = ensureCopy();
-    vector.erase(vector.begin() + index);
-  }
-
-  void set(size_t index, T value) { ensureCopy()[index] = value; }
-
-  size_t size() const { return ArrayRef<T>(*this).size(); }
-
-  bool empty() const { return ArrayRef<T>(*this).empty(); }
-
-  operator ArrayRef<T>() const {
-    return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
-  }
-
-private:
-  SmallVector<T> &ensureCopy() {
-    // Empty non-owning storage signals the array has been copied to the owning
-    // storage (or both are empty). Note: `nonOwning` should never reference
-    // `owningStorage`. This can lead to dangling references if the
-    // CopyOnWriteArrayRef<T> is copied.
-    if (!nonOwning.empty()) {
-      owningStorage = SmallVector<T>(nonOwning);
-      nonOwning = {};
-    }
-    return owningStorage;
-  }
-
-  ArrayRef<T> nonOwning;
-  SmallVector<T> owningStorage;
-};
-
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Support/ADTExtras.h b/mlir/include/mlir/Support/ADTExtras.h
new file mode 100644
index 000000000000000..1e4708f8f7d3f9e
--- /dev/null
+++ b/mlir/include/mlir/Support/ADTExtras.h
@@ -0,0 +1,82 @@
+//===- ADTExtras.h - Extra ADTs for use in MLIR -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_ADTEXTRAS_H
+#define MLIR_SUPPORT_ADTEXTRAS_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CopyOnWriteArrayRef<T>
+//===----------------------------------------------------------------------===//
+
+// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
+// modification. This is for use in the mlir::<Type>::Builders.
+template <typename T>
+class CopyOnWriteArrayRef {
+public:
+  CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};
+
+  CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
+    nonOwning = array;
+    owningStorage = {};
+    return *this;
+  }
+
+  void insert(size_t index, T value) {
+    SmallVector<T> &vector = ensureCopy();
+    vector.insert(vector.begin() + index, value);
+  }
+
+  void erase(size_t index) {
+    // Note: A copy can be avoided when just dropping the front/back dims.
+    if (isNonOwning() && index == 0) {
+      nonOwning = nonOwning.drop_front();
+    } else if (isNonOwning() && index == size() - 1) {
+      nonOwning = nonOwning.drop_back();
+    } else {
+      SmallVector<T> &vector = ensureCopy();
+      vector.erase(vector.begin() + index);
+    }
+  }
+
+  void set(size_t index, T value) { ensureCopy()[index] = value; }
+
+  size_t size() const { return ArrayRef<T>(*this).size(); }
+
+  bool empty() const { return ArrayRef<T>(*this).empty(); }
+
+  operator ArrayRef<T>() const {
+    return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
+  }
+
+private:
+  bool isNonOwning() const { return !nonOwning.empty(); }
+
+  SmallVector<T> &ensureCopy() {
+    // Empty non-owning storage signals the array has been copied to the owning
+    // storage (or both are empty). Note: `nonOwning` should never reference
+    // `owningStorage`. This can lead to dangling references if the
+    // CopyOnWriteArrayRef<T> is copied.
+    if (isNonOwning()) {
+      owningStorage = SmallVector<T>(nonOwning);
+      nonOwning = {};
+    }
+    return owningStorage;
+  }
+
+  ArrayRef<T> nonOwning;
+  SmallVector<T> owningStorage;
+};
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 618851a4e72667b..61264bc523648cf 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -199,9 +199,9 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   {
     // Drop some dims.
     RankedTensorType dropFrontTwoDims =
-        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(0);
+        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
     ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
-    ASSERT_EQ(tensorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
+    ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
   }
 
   {



More information about the Mlir-commits mailing list