[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Dec 4 02:08:14 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/74251
From c8e828ad69dc5d6802c581ff62582f130a0d70b9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Sun, 3 Dec 2023 20:01:42 +0000
Subject: [PATCH] [mlir] Add first-class support for scalability in VectorType
dims
Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to represent
scalability, and the 'scalableDims', an additional list of bools
attached to VectorType. This can be cumbersome to work with, and it is
easy to ignore the scalability of a dim, producing incorrect results.
For example, to correctly trim the leading unit dims of a VectorType,
you currently need to do something like:
```c++
while (!newShape.empty() && newShape.front() == 1 &&
       !newScalableDims.front()) {
  newShape = newShape.drop_front(1);
  newScalableDims = newScalableDims.drop_front(1);
}
```
This would be wrong if you (more naturally) wrote it as:
```c++
auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
  return dim == 1;
});
```
This would also trim scalable one dims (`[1]`), which, unlike their
fixed counterparts, are not unit dims.
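To make the pitfall concrete, here is a small standalone sketch that models `vector<1x[1]x4xf32>` as two parallel `std::vector`s rather than the MLIR types. The helper names are hypothetical; they only mirror the logic of the two snippets above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone illustration (not MLIR): a vector shape stored as two parallel
// lists, e.g. vector<1x[1]x4xf32> -> sizes {1, 1, 4}, scalable {F, T, F}.

// Naive: only looks at the sizes, so it also counts the scalable `[1]` dim.
std::size_t countLeadingOnesNaive(const std::vector<int64_t> &sizes) {
  std::size_t n = 0;
  while (n < sizes.size() && sizes[n] == 1)
    ++n;
  return n;
}

// Scalability-aware: a dim is a unit dim only if its size is 1 *and* it is
// fixed.
std::size_t countLeadingUnitDims(const std::vector<int64_t> &sizes,
                                 const std::vector<bool> &scalable) {
  assert(sizes.size() == scalable.size() && "parallel lists must match");
  std::size_t n = 0;
  while (n < sizes.size() && sizes[n] == 1 && !scalable[n])
    ++n;
  return n;
}
```

For sizes `{1, 1, 4}` with flags `{false, true, false}`, the naive helper returns 2 (it swallows the scalable `[1]`), while the scalability-aware one returns 1.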
This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.
Two new methods are added to VectorType:
```c++
/// Returns the value of the specified dimension (including scalability).
VectorDim VectorType::getDim(unsigned idx) const;
/// Returns the dimensions of this vector type (including scalability).
VectorDims VectorType::getDims() const;
```
These are backed by two new classes: `VectorDim` and `VectorDims`.
`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
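The comparison rule can be illustrated with a minimal standalone mimic. The `Dim` class below is hypothetical (it is not the `VectorDim` from this patch and uses no MLIR headers), but it follows the same idea: an explicit constructor, no conversion operator, and equality that compares both the size and the kind.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical standalone mimic of the VectorDim idea (not the MLIR class).
class Dim {
public:
  // Explicit: an int64_t never silently becomes a Dim.
  explicit constexpr Dim(int64_t minSize, bool scalable)
      : minSize(minSize), scalable(scalable) {}

  static constexpr Dim getFixed(int64_t n) { return Dim(n, false); }
  static constexpr Dim getScalable(int64_t n) { return Dim(n, true); }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return minSize; }

  // Equal only if both the minimum size and the kind (fixed/scalable) match.
  constexpr bool operator==(const Dim &other) const {
    return minSize == other.minSize && scalable == other.scalable;
  }
  constexpr bool operator!=(const Dim &other) const {
    return !(*this == other);
  }

  // There is no `operator int64_t()`, so `dim == 1` is a compile error;
  // callers must write `dim == Dim::getFixed(1)` (or `Dim::getScalable(1)`).

private:
  int64_t minSize;
  bool scalable;
};
```

In this sketch `Dim::getFixed(1) != Dim::getScalable(1)`, which is exactly the distinction the naive trim above loses.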
`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator, and a few common helper methods
(similar to those of ArrayRef).
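A rough sketch of how such a view can work, assuming plain `std::vector` storage in place of `ArrayRef`. `DimValue` and `DimsView` are hypothetical names; like `VectorDim::Indexer` in the diff, the view treats an empty list of scalable flags as "every dim is fixed" and assembles each element on the fly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical element type for the sketch: one (size, kind) pair.
struct DimValue {
  int64_t minSize;
  bool scalable;
};

// Hypothetical non-owning view over two parallel lists, in the spirit of
// VectorDims (not the MLIR class). Empty scalable flags mean "all fixed".
class DimsView {
public:
  DimsView(const std::vector<int64_t> &sizesRef,
           const std::vector<bool> &scalableRef)
      : sizes(&sizesRef), scalable(&scalableRef) {
    assert((scalableRef.empty() || scalableRef.size() == sizesRef.size()) &&
           "scalable flags must be empty or match sizes in length");
  }

  std::size_t size() const { return sizes->size(); }
  bool empty() const { return sizes->empty(); }

  // Elements are built on demand and returned by value, since no DimValue
  // objects are actually stored anywhere.
  DimValue operator[](std::size_t idx) const {
    assert(idx < size() && "index out of range");
    bool isScalable = !scalable->empty() && (*scalable)[idx];
    return DimValue{(*sizes)[idx], isScalable};
  }

private:
  // Non-owning: the caller must keep both vectors alive.
  const std::vector<int64_t> *sizes;
  const std::vector<bool> *scalable;
};
```

Returning elements by value is what lets a view over two separate arrays still hand out a single (size, kind) object; the `VectorDims::Iterator` in the diff likewise uses `VectorDim` (not a reference) as its reference type.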
There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.
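Because the storage of VectorType is unchanged, the `ArrayRef<VectorDim>` builder has to split the dims back into the two existing lists (the TableGen body does this with two `SmallVector`s). A minimal standalone sketch of that split, using hypothetical `DimSpec`/`splitDims` names and `std::vector` in place of the MLIR types:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a single dimension.
struct DimSpec {
  int64_t minSize;
  bool scalable;
};

// The two parallel lists a VectorType-like type would store.
struct SplitShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalableDims;
};

// Split a list of dims into sizes and scalability flags, mirroring the loop
// in the new ArrayRef<VectorDim> builder.
SplitShape splitDims(const std::vector<DimSpec> &dims) {
  SplitShape result;
  result.sizes.reserve(dims.size());
  result.scalableDims.reserve(dims.size());
  for (const DimSpec &dim : dims) {
    result.sizes.push_back(dim.minSize);
    result.scalableDims.push_back(dim.scalable);
  }
  return result;
}
```

For example, the dims of `vector<[4]x[4]xi32>` split into sizes `{4, 4}` and flags `{true, true}`.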
With these changes, the leading-unit-dim example from above becomes:
```c++
auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
  return dim == VectorDim::getFixed(1);
});
```
This is (to me) easier to read, and safer, as it is not possible to
forget to check the scalability of the dim: just comparing with `1`
would fail to build.
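The same trim can be sketched without MLIR. `Dim`, `dropWhile` and `isUnitDim` below are hypothetical stand-ins: this version copies into a `std::vector`, whereas `VectorDims::dropWhile` in the diff returns a non-owning view built with `llvm::find_if_not`. The data mirrors the `dropWhile` unit test in this patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for VectorDim: equality compares size *and* kind.
struct Dim {
  int64_t minSize;
  bool scalable;
  bool operator==(const Dim &other) const {
    return minSize == other.minSize && scalable == other.scalable;
  }
};

// Analogue of VectorDims::dropWhile: skip the leading dims matching `pred`
// and return the rest (as a copy in this sketch).
template <typename Pred>
std::vector<Dim> dropWhile(const std::vector<Dim> &dims, Pred pred) {
  auto firstKept = std::find_if_not(dims.begin(), dims.end(), pred);
  return std::vector<Dim>(firstKept, dims.end());
}

// Only fixed dims of size 1 are unit dims.
inline bool isUnitDim(const Dim &dim) { return dim == Dim{1, false}; }
```

For `1x1x1x[1]x[4]`, only the three fixed ones are dropped, leaving `[1]x[4]`.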
---
mlir/include/mlir/IR/BuiltinTypes.h | 192 ++++++++++++++++++
mlir/include/mlir/IR/BuiltinTypes.td | 23 +++
.../ArmSME/sme-tile-type-conversion.mlir | 7 +
.../Conversion/ArmSMEToLLVM/CMakeLists.txt | 20 ++
.../ArmSMEToLLVM/TestTileTypeConversion.cpp | 82 ++++++++
mlir/unittests/IR/ShapedTypeTest.cpp | 101 +++++++++
6 files changed, 425 insertions(+)
create mode 100644 mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
create mode 100644 mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
create mode 100644 mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
namespace llvm {
class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+ explicit constexpr VectorDim(int64_t quantity, bool scalable)
+ : quantity(quantity), scalable(scalable) {}
+
+ /// Constructs a new fixed dimension.
+ constexpr static VectorDim getFixed(int64_t quantity) {
+ return VectorDim(quantity, false);
+ }
+
+ /// Constructs a new scalable dimension.
+ constexpr static VectorDim getScalable(int64_t quantity) {
+ return VectorDim(quantity, true);
+ }
+
+ /// Returns true if this dimension is scalable.
+ constexpr bool isScalable() const { return scalable; }
+
+ /// Returns true if this dimension is fixed.
+ constexpr bool isFixed() const { return !isScalable(); }
+
+ /// Returns the minimum number of elements this dimension can contain.
+ constexpr int64_t getMinSize() const { return quantity; }
+
+ /// If this dimension is fixed, returns the number of elements; otherwise
+ /// asserts.
+ constexpr int64_t getFixedSize() const {
+ assert(isFixed());
+ return quantity;
+ }
+
+ constexpr bool operator==(VectorDim const &dim) const {
+ return quantity == dim.quantity && scalable == dim.scalable;
+ }
+
+ constexpr bool operator!=(VectorDim const &dim) const {
+ return !(*this == dim);
+ }
+
+ /// Helper class for indexing into a list of sizes and a (possibly empty)
+ /// list of scalable dimensions, extracting VectorDim elements.
+ struct Indexer {
+ explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+ : sizes(sizes), scalableDims(scalableDims) {
+ assert(
+ (scalableDims.empty() ||
+ sizes.size() == scalableDims.size()) &&
+ "expected `scalableDims` to be empty or match `sizes` in length");
+ }
+
+ VectorDim operator[](size_t idx) const {
+ int64_t size = sizes[idx];
+ bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+ return VectorDim(size, scalable);
+ }
+
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableDims;
+ };
+
+private:
+ int64_t quantity;
+ bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+ using VectorDim::Indexer::Indexer;
+
+ class Iterator : public llvm::iterator_facade_base<
+ Iterator, std::random_access_iterator_tag, VectorDim,
+ std::ptrdiff_t, VectorDim, VectorDim> {
+ public:
+ Iterator(VectorDim::Indexer indexer, size_t index)
+ : indexer(indexer), index(index) {}
+
+ // Iterator boilerplate.
+ ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+ bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+ bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+ Iterator &operator+=(ptrdiff_t offset) {
+ index += offset;
+ return *this;
+ }
+ Iterator &operator-=(ptrdiff_t offset) {
+ index -= offset;
+ return *this;
+ }
+ VectorDim operator*() const { return indexer[index]; }
+
+ VectorDim::Indexer getIndexer() const { return indexer; }
+ ptrdiff_t getIndex() const { return index; }
+
+ private:
+ VectorDim::Indexer indexer;
+ ptrdiff_t index;
+ };
+
+ /// Construct from iterator pair.
+ VectorDims(Iterator begin, Iterator end)
+ : VectorDims(VectorDims(begin.getIndexer())
+ .slice(begin.getIndex(), end - begin)) {}
+
+ VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {}
+
+ Iterator begin() const { return Iterator(*this, 0); }
+ Iterator end() const { return Iterator(*this, size()); }
+
+ /// Check if the dims are empty.
+ bool empty() const { return sizes.empty(); }
+
+ /// Get the number of dims.
+ size_t size() const { return sizes.size(); }
+
+ /// Return the first dim.
+ VectorDim front() const { return (*this)[0]; }
+
+ /// Return the last dim.
+ VectorDim back() const { return (*this)[size() - 1]; }
+
+ /// Chop off the first \p n dims, and keep the next \p m
+ /// dims.
+ VectorDims slice(size_t n, size_t m) const {
+ ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+ ArrayRef<bool> newScalableDims =
+ scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+ return VectorDims(newSizes, newScalableDims);
+ }
+
+ /// Drop the first \p n dims.
+ VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+ /// Drop the last \p n dims.
+ VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+ /// Return a view of *this without the leading dims matching the predicate.
+ template <class PredicateT>
+ VectorDims dropWhile(PredicateT predicate) const {
+ return VectorDims(llvm::find_if_not(*this, predicate), end());
+ }
+
+ /// Return the underlying sizes.
+ ArrayRef<int64_t> getSizes() const { return sizes; }
+
+ /// Return the underlying scalable dims.
+ ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+ /// Check for dim equality.
+ bool equals(VectorDims rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Check for dim equality.
+ bool equals(ArrayRef<VectorDim> rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+ SmallVector<int64_t> sizes;
+ SmallVector<bool> scalableDims;
+ for (VectorDim dim : shape) {
+ sizes.push_back(dim.getMinSize());
+ scalableDims.push_back(dim.isScalable());
+ }
+ return get(sizes, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+ return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+ /// Returns the value of the specified dimension (including scalability).
+ VectorDim getDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid dim index for vector type");
+ return getDims()[idx];
+ }
+
+ /// Returns the dimensions of this vector type (including scalability).
+ VectorDims getDims() const {
+ return VectorDims(getShape(), getScalableDims());
+ }
+
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir b/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
new file mode 100644
index 0000000000000..b4f702e3aae6d
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
@@ -0,0 +1,7 @@
+
+
+func.func @function_using_sme_size_2d_scalable_vec(%vec: vector<[4]x[4]xi32>) -> vector<[4]x[4]xi32> {
+ %c7 = arith.constant 7 : i32
+ %newVec = vector.insert %c7, %vec[0, 0] : i32 into vector<[4]x[4]xi32>
+ return %newVec : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..c713f085373ab
--- /dev/null
+++ b/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestArmSMEToLLVM
+ TestTileTypeConversion.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEToLLVM
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRLLVMIRTransforms
+ MLIRPass
+ MLIRTestDialect
+ )
+
+target_include_directories(MLIRTestArmSMEToLLVM
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
+ ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
+ )
diff --git a/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp b/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp
new file mode 100644
index 0000000000000..83febe2774d27
--- /dev/null
+++ b/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp
@@ -0,0 +1,82 @@
+//===- TestTileTypeConversion.cpp - Test LLVM Conversion of ArmSME tiles --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+class TestTypeProducerOpConverter
+ : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
+public:
+ using ConvertOpToLLVMPattern<
+ test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::ZeroOp>(op, getVoidPtrType());
+ return success();
+ }
+};
+
+struct TestConvertCallOp
+ : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertCallOp)
+
+ void getDependentDialects(DialectRegistry &registry) const final {
+ registry.insert<LLVM::LLVMDialect>();
+ }
+ StringRef getArgument() const final { return "test-convert-call-op"; }
+ StringRef getDescription() const final {
+ return "Tests conversion of `func.call` to `llvm.call` in "
+ "presence of custom types";
+ }
+
+ void runOnOperation() override {
+ ModuleOp m = getOperation();
+
+ LowerToLLVMOptions options(m.getContext());
+
+ // Populate type conversions.
+ LLVMTypeConverter typeConverter(m.getContext(), options);
+ typeConverter.addConversion([&](test::TestType type) {
+ return LLVM::LLVMPointerType::get(m.getContext());
+ });
+ typeConverter.addConversion([&](test::SimpleAType type) {
+ return IntegerType::get(type.getContext(), 42);
+ });
+
+ // Populate patterns.
+ RewritePatternSet patterns(m.getContext());
+ populateFuncToLLVMConversionPatterns(typeConverter, patterns);
+ patterns.add<TestTypeProducerOpConverter>(typeConverter);
+
+ // Set target.
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addIllegalDialect<test::TestDialect>();
+ target.addIllegalDialect<func::FuncDialect>();
+
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+TEST(ShapedTypeTest, VectorDims) {
+ MLIRContext context;
+ Type f32 = FloatType::getF32(&context);
+
+ SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+ VectorDim::getFixed(8), VectorDim::getScalable(9),
+ VectorDim::getFixed(1)};
+ VectorType vectorType = VectorType::get(f32, dims);
+
+ // Directly check values
+ {
+ auto dim0 = vectorType.getDim(0);
+ ASSERT_EQ(dim0.getMinSize(), 2);
+ ASSERT_TRUE(dim0.isFixed());
+
+ auto dim1 = vectorType.getDim(1);
+ ASSERT_EQ(dim1.getMinSize(), 4);
+ ASSERT_TRUE(dim1.isScalable());
+
+ auto dim2 = vectorType.getDim(2);
+ ASSERT_EQ(dim2.getMinSize(), 8);
+ ASSERT_TRUE(dim2.isFixed());
+
+ auto dim3 = vectorType.getDim(3);
+ ASSERT_EQ(dim3.getMinSize(), 9);
+ ASSERT_TRUE(dim3.isScalable());
+
+ auto dim4 = vectorType.getDim(4);
+ ASSERT_EQ(dim4.getMinSize(), 1);
+ ASSERT_TRUE(dim4.isFixed());
+ }
+
+ // Test indexing via getDim(idx)
+ {
+ for (unsigned i = 0; i < dims.size(); i++)
+ ASSERT_EQ(vectorType.getDim(i), dims[i]);
+ }
+
+ // Test using VectorDims::Iterator in for-each loop
+ {
+ unsigned i = 0;
+ for (VectorDim dim : vectorType.getDims())
+ ASSERT_EQ(dim, dims[i++]);
+ ASSERT_EQ(i, vectorType.getRank());
+ }
+
+ // Test using VectorDims::Iterator in LLVM iterator helper
+ {
+ for (auto [dim, expectedDim] :
+ llvm::zip_equal(vectorType.getDims(), dims)) {
+ ASSERT_EQ(dim, expectedDim);
+ }
+ }
+
+ // Test dropFront()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropFront();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+ }
+
+ // Test dropBack()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropBack();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i]);
+ }
+
+ // Test front()
+ { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+ // Test back()
+ { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+ // Test dropWhile.
+ {
+ SmallVector<VectorDim> dims{
+ VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+ VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+ VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+ ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+ unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+ // Drop leading unit dims.
+ auto withoutLeadingUnitDims =
+ vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+ SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+ VectorDim::getScalable(4)};
+ ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+ }
+}
+
} // namespace