[Mlir-commits] [mlir] [ADT] Support appending initializer list (PR #69891)

Jakub Kuderski llvmlistbot at llvm.org
Sun Oct 22 19:01:03 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/69891

>From 45d0f07ed353177c1919c91a9d17ee1b8a79fafc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 22 Oct 2023 20:43:36 -0400
Subject: [PATCH 1/2] [ADT] Support appending initializer list

This is so that we can append multiple values at once without having to
create a temporary array or repeatedly call `push_back`.

Use the new overload of `append_range` to clean up the SPIR-V serializer
code. (NFC)
---
 llvm/include/llvm/ADT/STLExtras.h             | 12 +++-
 llvm/unittests/ADT/STLExtrasTest.cpp          |  9 +++
 .../Target/SPIRV/Serialization/Serializer.cpp | 66 +++++++++----------
 3 files changed, 52 insertions(+), 35 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index d0b79fa91c03130..3f0e9d4c7f8be4d 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2029,14 +2029,22 @@ void erase_value(Container &C, ValueType V) {
   C.erase(std::remove(C.begin(), C.end(), V), C.end());
 }
 
-/// Wrapper function to append a range to a container.
+/// Wrapper function to append range `R` to container `C`.
 ///
 /// C.insert(C.end(), R.begin(), R.end());
 template <typename Container, typename Range>
-inline void append_range(Container &C, Range &&R) {
+void append_range(Container &C, Range &&R) {
   C.insert(C.end(), adl_begin(R), adl_end(R));
 }
 
+/// Appends all elements in the initializer list `Values` to the container `C`.
+/// This can be used as a replacement for repeated calls to `.push_back(X)`.
+/// Note that all values passed in the initializer list are copied.
+template <typename Container, typename T>
+void append_range(Container &C, std::initializer_list<T> Values) {
+  append_range<Container, std::initializer_list<T>>(C, std::move(Values));
+}
+
 /// Given a sequence container Cont, replace the range [ContIt, ContEnd) with
 /// the range [ValIt, ValEnd) (which is not from the same container).
 template<typename Container, typename RandomAccessIterator>
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index c34760d83874daf..8f63789633cac51 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -541,6 +541,15 @@ TEST(STLExtrasTest, AppendRange) {
   EXPECT_THAT(Str, ElementsAre('a', 'b', 'c', '\0', 'd', 'e', 'f', '\0'));
 }
 
+TEST(STLExtrasTest, AppendRangeInitializerList) {
+  std::vector<int> V = {1, 2};
+  append_range(V, {3});
+  EXPECT_THAT(V, ElementsAre(1, 2, 3));
+
+  append_range(V, {4, 5});
+  EXPECT_THAT(V, ElementsAre(1, 2, 3, 4, 5));
+}
+
 TEST(STLExtrasTest, ADLTest) {
   some_namespace::some_struct s{{1, 2, 3, 4, 5}, ""};
   some_namespace::some_struct s2{{2, 4, 6, 8, 10}, ""};
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index dad085e21b42727..46c38964f9a6019 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
@@ -443,13 +444,13 @@ LogicalResult Serializer::prepareBasicType(
     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
       return failure();
 
-    operands.push_back(sampledTypeID);
-    operands.push_back(static_cast<uint32_t>(imageType.getDim()));
-    operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
-    operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
-    operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
-    operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
-    operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+    llvm::append_range(
+        operands, {sampledTypeID, static_cast<uint32_t>(imageType.getDim()),
+                   static_cast<uint32_t>(imageType.getDepthInfo()),
+                   static_cast<uint32_t>(imageType.getArrayedInfo()),
+                   static_cast<uint32_t>(imageType.getSamplingInfo()),
+                   static_cast<uint32_t>(imageType.getSamplerUseInfo()),
+                   static_cast<uint32_t>(imageType.getImageFormat())});
     return success();
   }
 
@@ -605,13 +606,13 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    operands.push_back(elementTypeID);
-    operands.push_back(
-        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
-    operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
-    operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
-    operands.push_back(
-        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
+    llvm::append_range(
+        operands,
+        {elementTypeID,
+         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
+         getConstantOp(cooperativeMatrixType.getRows()),
+         getConstantOp(cooperativeMatrixType.getColumns()),
+         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()))});
     return success();
   }
 
@@ -627,11 +628,12 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    operands.push_back(elementTypeID);
-    operands.push_back(
-        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
-    operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
-    operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
+    llvm::append_range(
+        operands,
+        {elementTypeID,
+         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
+         getConstantOp(cooperativeMatrixType.getRows()),
+         getConstantOp(cooperativeMatrixType.getColumns())});
     return success();
   }
 
@@ -646,13 +648,13 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    operands.push_back(elementTypeID);
-    operands.push_back(getConstantOp(jointMatrixType.getRows()));
-    operands.push_back(getConstantOp(jointMatrixType.getColumns()));
-    operands.push_back(getConstantOp(
-        static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
-    operands.push_back(
-        getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+    llvm::append_range(
+        operands,
+        {elementTypeID, getConstantOp(jointMatrixType.getRows()),
+         getConstantOp(jointMatrixType.getColumns()),
+         getConstantOp(
+             static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
+         getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()))});
     return success();
   }
 
@@ -663,8 +665,7 @@ LogicalResult Serializer::prepareBasicType(
       return failure();
     }
     typeEnum = spirv::Opcode::OpTypeMatrix;
-    operands.push_back(elementTypeID);
-    operands.push_back(matrixType.getNumColumns());
+    llvm::append_range(operands, {elementTypeID, matrixType.getNumColumns()});
     return success();
   }
 
@@ -1261,11 +1262,10 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
                                          spirv::Decoration decoration,
                                          ArrayRef<uint32_t> params) {
   uint32_t wordCount = 3 + params.size();
-  decorations.push_back(
-      spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
-  decorations.push_back(target);
-  decorations.push_back(static_cast<uint32_t>(decoration));
-  decorations.append(params.begin(), params.end());
+  llvm::append_range(decorations, {spirv::getPrefixedOpcode(
+                                       wordCount, spirv::Opcode::OpDecorate),
+                                   target, static_cast<uint32_t>(decoration)});
+  llvm::append_range(decorations, params);
   return success();
 }
 

>From 99ac7759295fa5e6f4513c03604fe0f2d259f3e6 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 22 Oct 2023 22:00:17 -0400
Subject: [PATCH 2/2] Do not use initializer list, add new function instead

---
 llvm/include/llvm/ADT/STLExtras.h             | 12 ++--
 llvm/unittests/ADT/STLExtrasTest.cpp          | 31 +++++++---
 .../Target/SPIRV/Serialization/Serializer.cpp | 59 +++++++++----------
 3 files changed, 58 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 3f0e9d4c7f8be4d..1923072960c7008 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2037,12 +2037,17 @@ void append_range(Container &C, Range &&R) {
   C.insert(C.end(), adl_begin(R), adl_end(R));
 }
 
-/// Appends all elements in the initializer list `Values` to the container `C`.
-/// This can be used as a replacement for repeated calls to `.push_back(X)`.
-/// Note that all values passed in the initializer list are copied.
-template <typename Container, typename T>
-void append_range(Container &C, std::initializer_list<T> Values) {
-  append_range<Container, std::initializer_list<T>>(C, std::move(Values));
+/// Appends all `Values` to container `C`, inserting them in argument order.
+///
+/// `C` must support `reserve` and `insert(iterator, value)`. As with any
+/// function call, the order in which the argument expressions are evaluated
+/// is unspecified; compute side-effecting values into locals first when that
+/// order matters.
+template <typename Container, typename... Args>
+void append_values(Container &C, Args &&...Values) {
+  C.reserve(range_size(C) + sizeof...(Args));
+  // Append all values one by one.
+  ((void)C.insert(C.end(), std::forward<Args>(Values)), ...);
 }
 
 /// Given a sequence container Cont, replace the range [ContIt, ContEnd) with
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 8f63789633cac51..7db339e4ef31cdc 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -18,12 +18,14 @@
 #include <list>
 #include <tuple>
 #include <type_traits>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
 using namespace llvm;
 
 using testing::ElementsAre;
+using testing::UnorderedElementsAre;
 
 namespace {
 
@@ -541,13 +543,28 @@ TEST(STLExtrasTest, AppendRange) {
   EXPECT_THAT(Str, ElementsAre('a', 'b', 'c', '\0', 'd', 'e', 'f', '\0'));
 }
 
-TEST(STLExtrasTest, AppendRangeInitializerList) {
-  std::vector<int> V = {1, 2};
-  append_range(V, {3});
-  EXPECT_THAT(V, ElementsAre(1, 2, 3));
-
-  append_range(V, {4, 5});
-  EXPECT_THAT(V, ElementsAre(1, 2, 3, 4, 5));
+TEST(STLExtrasTest, AppendValues) {
+  std::vector<int> Vals = {1, 2};
+  append_values(Vals, 3);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3));
+
+  append_values(Vals, 4, 5);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3, 4, 5));
+
+  std::vector<StringRef> Strs;
+  std::string A = "A";
+  std::string B = "B";
+  std::string C = "C";
+  append_values(Strs, A, B);
+  EXPECT_THAT(Strs, ElementsAre(A, B));
+  append_values(Strs, C);
+  EXPECT_THAT(Strs, ElementsAre(A, B, C));
+
+  std::unordered_set<int> Set;
+  append_values(Set, 1, 2);
+  EXPECT_THAT(Set, UnorderedElementsAre(1, 2));
+  append_values(Set, 3, 1);
+  EXPECT_THAT(Set, UnorderedElementsAre(1, 2, 3));
 }
 
 TEST(STLExtrasTest, ADLTest) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 46c38964f9a6019..22fcc4939317be9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -444,13 +444,13 @@ LogicalResult Serializer::prepareBasicType(
     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
       return failure();
 
-    llvm::append_range(
-        operands, {sampledTypeID, static_cast<uint32_t>(imageType.getDim()),
-                   static_cast<uint32_t>(imageType.getDepthInfo()),
-                   static_cast<uint32_t>(imageType.getArrayedInfo()),
-                   static_cast<uint32_t>(imageType.getSamplingInfo()),
-                   static_cast<uint32_t>(imageType.getSamplerUseInfo()),
-                   static_cast<uint32_t>(imageType.getImageFormat())});
+    llvm::append_values(operands, sampledTypeID,
+                        static_cast<uint32_t>(imageType.getDim()),
+                        static_cast<uint32_t>(imageType.getDepthInfo()),
+                        static_cast<uint32_t>(imageType.getArrayedInfo()),
+                        static_cast<uint32_t>(imageType.getSamplingInfo()),
+                        static_cast<uint32_t>(imageType.getSamplerUseInfo()),
+                        static_cast<uint32_t>(imageType.getImageFormat()));
     return success();
   }
 
@@ -606,13 +606,17 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    llvm::append_range(
-        operands,
-        {elementTypeID,
-         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
-         getConstantOp(cooperativeMatrixType.getRows()),
-         getConstantOp(cooperativeMatrixType.getColumns()),
-         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()))});
+    // Compute the constant IDs into locals so they are materialized in source
+    // order: the evaluation order of function arguments is unspecified, and
+    // `prepareConstantInt` may emit a new constant (and assign it an ID).
+    uint32_t scopeID =
+        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
+    uint32_t rowsID = getConstantOp(cooperativeMatrixType.getRows());
+    uint32_t columnsID = getConstantOp(cooperativeMatrixType.getColumns());
+    uint32_t useID =
+        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()));
+    llvm::append_values(operands, elementTypeID, scopeID, rowsID, columnsID,
+                        useID);
     return success();
   }

@@ -628,12 +627,14 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    llvm::append_range(
-        operands,
-        {elementTypeID,
-         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
-         getConstantOp(cooperativeMatrixType.getRows()),
-         getConstantOp(cooperativeMatrixType.getColumns())});
+    // Compute the constant IDs into locals so they are materialized in source
+    // order: the evaluation order of function arguments is unspecified, and
+    // `prepareConstantInt` may emit a new constant (and assign it an ID).
+    uint32_t scopeID =
+        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
+    uint32_t rowsID = getConstantOp(cooperativeMatrixType.getRows());
+    uint32_t columnsID = getConstantOp(cooperativeMatrixType.getColumns());
+    llvm::append_values(operands, elementTypeID, scopeID, rowsID, columnsID);
     return success();
   }

@@ -648,13 +646,17 @@ LogicalResult Serializer::prepareBasicType(
       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
       return prepareConstantInt(loc, attr);
     };
-    llvm::append_range(
-        operands,
-        {elementTypeID, getConstantOp(jointMatrixType.getRows()),
-         getConstantOp(jointMatrixType.getColumns()),
-         getConstantOp(
-             static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
-         getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()))});
+    // Compute the constant IDs into locals so they are materialized in source
+    // order: the evaluation order of function arguments is unspecified, and
+    // `prepareConstantInt` may emit a new constant (and assign it an ID).
+    uint32_t rowsID = getConstantOp(jointMatrixType.getRows());
+    uint32_t columnsID = getConstantOp(jointMatrixType.getColumns());
+    uint32_t layoutID =
+        getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout()));
+    uint32_t scopeID =
+        getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()));
+    llvm::append_values(operands, elementTypeID, rowsID, columnsID, layoutID,
+                        scopeID);
     return success();
   }

@@ -665,7 +661,7 @@ LogicalResult Serializer::prepareBasicType(
       return failure();
     }
     typeEnum = spirv::Opcode::OpTypeMatrix;
-    llvm::append_range(operands, {elementTypeID, matrixType.getNumColumns()});
+    llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
     return success();
   }
 
@@ -1262,9 +1258,10 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
                                          spirv::Decoration decoration,
                                          ArrayRef<uint32_t> params) {
   uint32_t wordCount = 3 + params.size();
-  llvm::append_range(decorations, {spirv::getPrefixedOpcode(
-                                       wordCount, spirv::Opcode::OpDecorate),
-                                   target, static_cast<uint32_t>(decoration)});
+  llvm::append_values(
+      decorations,
+      spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
+      static_cast<uint32_t>(decoration));
   llvm::append_range(decorations, params);
   return success();
 }



More information about the Mlir-commits mailing list