[Mlir-commits] [mlir] 4c28e66 - [ADT] Support appending multiple values (#69891)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 24 10:20:10 PDT 2023
Author: Jakub Kuderski
Date: 2023-10-24T13:20:05-04:00
New Revision: 4c28e666a75805cf453c0d4ed153bfbc779c45c5
URL: https://github.com/llvm/llvm-project/commit/4c28e666a75805cf453c0d4ed153bfbc779c45c5
DIFF: https://github.com/llvm/llvm-project/commit/4c28e666a75805cf453c0d4ed153bfbc779c45c5.diff
LOG: [ADT] Support appending multiple values (#69891)
This is so that we can append multiple values at once without having to
create a temporary array or repeatedly call `push_back`.
Use the new function `append_values` to clean up the SPIR-V serializer
code. (NFC)
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
llvm/unittests/ADT/STLExtrasTest.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index d0b79fa91c03130..1923072960c7008 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2029,14 +2029,22 @@ void erase_value(Container &C, ValueType V) {
C.erase(std::remove(C.begin(), C.end(), V), C.end());
}
-/// Wrapper function to append a range to a container.
+/// Wrapper function to append range `R` to container `C`.
///
/// C.insert(C.end(), R.begin(), R.end());
+///
+/// The iterators are obtained with `adl_begin(R)` / `adl_end(R)` rather than
+/// by calling member `begin()` / `end()` directly.
template <typename Container, typename Range>
-inline void append_range(Container &C, Range &&R) {
+void append_range(Container &C, Range &&R) {
C.insert(C.end(), adl_begin(R), adl_end(R));
}
+namespace detail {
+/// Detects containers that expose both `capacity()` and `reserve()`.
+template <typename Container>
+using check_has_capacity_and_reserve =
+    decltype(std::declval<Container &>().reserve(
+        std::declval<Container &>().capacity()));
+} // namespace detail
+
+/// Appends all `Values` to container `C`, inserting each one at `C.end()` in
+/// argument order.
+///
+/// When `C` exposes both `capacity()` and `reserve()` (e.g. std::vector,
+/// SmallVector), room for all values is reserved up front, at least doubling
+/// the capacity whenever it has to grow. Reserving exactly `size() + N` on
+/// every call would defeat geometric growth (std::vector implementations
+/// typically allocate exactly the requested amount), making repeated calls
+/// quadratic. Other containers (e.g. std::unordered_set, std::list) rely on
+/// the growth policy of their own `insert`.
+///
+/// Note: as with any function call, the argument expressions are evaluated
+/// in an unspecified order before any value is inserted. Do not pass
+/// expressions whose side effects must happen in a particular sequence.
+template <typename Container, typename... Args>
+void append_values(Container &C, Args &&...Values) {
+  if constexpr (is_detected<detail::check_has_capacity_and_reserve,
+                            Container>::value) {
+    size_t Needed = range_size(C) + sizeof...(Args);
+    size_t Capacity = C.capacity();
+    if (Needed > Capacity)
+      C.reserve(std::max<size_t>(Needed, 2 * Capacity));
+  }
+  // Append all values one by one; the comma fold inserts them left to right.
+  ((void)C.insert(C.end(), std::forward<Args>(Values)), ...);
+}
+
/// Given a sequence container Cont, replace the range [ContIt, ContEnd) with
/// the range [ValIt, ValEnd) (which is not from the same container).
template<typename Container, typename RandomAccessIterator>
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index c34760d83874daf..7db339e4ef31cdc 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -18,12 +18,14 @@
#include <list>
#include <tuple>
#include <type_traits>
+#include <unordered_set>
#include <utility>
#include <vector>
using namespace llvm;
using testing::ElementsAre;
+using testing::UnorderedElementsAre;
namespace {
@@ -541,6 +543,30 @@ TEST(STLExtrasTest, AppendRange) {
EXPECT_THAT(Str, ElementsAre('a', 'b', 'c', '\0', 'd', 'e', 'f', '\0'));
}
+TEST(STLExtrasTest, AppendValues) {
+  std::vector<int> Vals = {1, 2};
+  append_values(Vals, 3);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3));
+
+  append_values(Vals, 4, 5);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3, 4, 5));
+
+  // An empty pack is a no-op.
+  append_values(Vals);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3, 4, 5));
+
+  // Repeated single-element appends keep the elements in order.
+  for (int I = 6; I <= 8; ++I)
+    append_values(Vals, I);
+  EXPECT_THAT(Vals, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8));
+
+  std::vector<StringRef> Strs;
+  std::string A = "A";
+  std::string B = "B";
+  std::string C = "C";
+  append_values(Strs, A, B);
+  EXPECT_THAT(Strs, ElementsAre(A, B));
+  append_values(Strs, C);
+  EXPECT_THAT(Strs, ElementsAre(A, B, C));
+
+  // Rvalue arguments are forwarded into the container.
+  std::vector<std::string> Owned;
+  std::string D = "D";
+  append_values(Owned, std::move(D), std::string("E"));
+  EXPECT_THAT(Owned, ElementsAre("D", "E"));
+
+  std::unordered_set<int> Set;
+  append_values(Set, 1, 2);
+  EXPECT_THAT(Set, UnorderedElementsAre(1, 2));
+  append_values(Set, 3, 1);
+  EXPECT_THAT(Set, UnorderedElementsAre(1, 2, 3));
+}
+
TEST(STLExtrasTest, ADLTest) {
some_namespace::some_struct s{{1, 2, 3, 4, 5}, ""};
some_namespace::some_struct s2{{2, 4, 6, 8, 10}, ""};
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index dad085e21b42727..22fcc4939317be9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
@@ -443,13 +444,13 @@ LogicalResult Serializer::prepareBasicType(
if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
return failure();
- operands.push_back(sampledTypeID);
- operands.push_back(static_cast<uint32_t>(imageType.getDim()));
- operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
- operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
- operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
- operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
- operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+ llvm::append_values(operands, sampledTypeID,
+ static_cast<uint32_t>(imageType.getDim()),
+ static_cast<uint32_t>(imageType.getDepthInfo()),
+ static_cast<uint32_t>(imageType.getArrayedInfo()),
+ static_cast<uint32_t>(imageType.getSamplingInfo()),
+ static_cast<uint32_t>(imageType.getSamplerUseInfo()),
+ static_cast<uint32_t>(imageType.getImageFormat()));
return success();
}
@@ -605,12 +606,11 @@ LogicalResult Serializer::prepareBasicType(
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
- operands.push_back(elementTypeID);
- operands.push_back(
- getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
- operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
- operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
- operands.push_back(
+ llvm::append_values(
+ operands, elementTypeID,
+ getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
+ getConstantOp(cooperativeMatrixType.getRows()),
+ getConstantOp(cooperativeMatrixType.getColumns()),
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
return success();
}
@@ -627,11 +627,11 @@ LogicalResult Serializer::prepareBasicType(
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
- operands.push_back(elementTypeID);
- operands.push_back(
- getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
- operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
- operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
+ llvm::append_values(
+ operands, elementTypeID,
+ getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
+ getConstantOp(cooperativeMatrixType.getRows()),
+ getConstantOp(cooperativeMatrixType.getColumns()));
return success();
}
@@ -646,12 +646,10 @@ LogicalResult Serializer::prepareBasicType(
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
- operands.push_back(elementTypeID);
- operands.push_back(getConstantOp(jointMatrixType.getRows()));
- operands.push_back(getConstantOp(jointMatrixType.getColumns()));
- operands.push_back(getConstantOp(
- static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
- operands.push_back(
+ llvm::append_values(
+ operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
+ getConstantOp(jointMatrixType.getColumns()),
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
return success();
}
@@ -663,8 +661,7 @@ LogicalResult Serializer::prepareBasicType(
return failure();
}
typeEnum = spirv::Opcode::OpTypeMatrix;
- operands.push_back(elementTypeID);
- operands.push_back(matrixType.getNumColumns());
+ llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
return success();
}
@@ -1261,11 +1258,11 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
spirv::Decoration decoration,
ArrayRef<uint32_t> params) {
uint32_t wordCount = 3 + params.size();
- decorations.push_back(
- spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
- decorations.push_back(target);
- decorations.push_back(static_cast<uint32_t>(decoration));
- decorations.append(params.begin(), params.end());
+ llvm::append_values(
+ decorations,
+ spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
+ static_cast<uint32_t>(decoration));
+ llvm::append_range(decorations, params);
return success();
}
More information about the Mlir-commits mailing list