[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Dec 4 02:08:14 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/74251

From c8e828ad69dc5d6802c581ff62582f130a0d70b9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Sun, 3 Dec 2023 20:01:42 +0000
Subject: [PATCH] [mlir] Add first-class support for scalability in VectorType
 dims

Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to represent
scalability, and 'scalableDims', an additional list of bools attached
to VectorType. This is cumbersome to work with, and makes it easy to
ignore the scalability of a dim, producing incorrect results.

For example, to correctly trim the leading unit dims of a VectorType,
you currently need to do something like:

```c++
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }
```

The same logic would be wrong if you (more naturally) wrote it as:

```c++
  auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
    return dim == 1;
  });
```

This would also trim scalable one dims (`[1]`), which, unlike their
fixed counterparts, are not unit dims.
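To make the pitfall concrete, below is a minimal standalone sketch in
plain C++17 (no MLIR dependency) that models a vector shape the way
VectorType stores it: two parallel lists of sizes and scalability flags.
The `Shape` struct and the `trimNaive`/`trimLeadingUnitDims` helpers are
hypothetical names used only for illustration. For `vector<1x[1]x4xf32>`
the naive version drops both leading 1s, while the scalability-aware
version stops at the scalable `[1]`.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of VectorType's storage: sizes and scalable flags kept in
// two parallel lists. vector<1x[1]x4xf32> is {1, 1, 4} / {false, true, false}.
struct Shape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Naive version: only looks at the sizes, so it also drops the scalable [1].
Shape trimNaive(const Shape &s) {
  size_t i = 0;
  while (i < s.sizes.size() && s.sizes[i] == 1)
    ++i;
  return {std::vector<int64_t>(s.sizes.begin() + i, s.sizes.end()),
          std::vector<bool>(s.scalable.begin() + i, s.scalable.end())};
}

// Scalability-aware version: a dim is a unit dim only if it is 1 AND fixed.
Shape trimLeadingUnitDims(const Shape &s) {
  size_t i = 0;
  while (i < s.sizes.size() && s.sizes[i] == 1 && !s.scalable[i])
    ++i;
  return {std::vector<int64_t>(s.sizes.begin() + i, s.sizes.end()),
          std::vector<bool>(s.scalable.begin() + i, s.scalable.end())};
}
```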

This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.
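Because no single array holds a (size, scalability) pair, such an
iterator cannot hand out references; it has to build each dim by value
on dereference. The following is a minimal sketch of that idea, assuming
plain `std::vector` storage. `Dim`, `DimIterator` and `DimRange` are
hypothetical names, and unlike the patch's random-access iterator (built
on `llvm::iterator_facade_base`), this one only supports what a
range-based `for` loop needs.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical dim value, assembled on the fly from the two lists.
struct Dim {
  int64_t size;
  bool scalable;
};

// Minimal iterator over two parallel lists. Dereferencing returns a Dim by
// value because there is no stored Dim object to refer to.
class DimIterator {
public:
  DimIterator(const std::vector<int64_t> *sizes,
              const std::vector<bool> *scalable, size_t index)
      : sizes(sizes), scalable(scalable), index(index) {}

  Dim operator*() const { return Dim{(*sizes)[index], (*scalable)[index]}; }

  DimIterator &operator++() {
    ++index;
    return *this;
  }

  bool operator!=(const DimIterator &rhs) const { return index != rhs.index; }

private:
  const std::vector<int64_t> *sizes;
  const std::vector<bool> *scalable;
  size_t index;
};

// Non-owning range over the two lists; both must outlive the range.
struct DimRange {
  const std::vector<int64_t> &sizes;
  const std::vector<bool> &scalable;

  DimIterator begin() const { return DimIterator(&sizes, &scalable, 0); }
  DimIterator end() const {
    return DimIterator(&sizes, &scalable, sizes.size());
  }
};
```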

Two new methods are added to VectorType:

```c++
/// Returns the value of the specified dimension (including scalability).
VectorDim VectorType::getDim(unsigned idx) const;

/// Returns the dimensions of this vector type (including scalability).
VectorDims VectorType::getDims() const;
```

These are backed by two new classes: `VectorDim` and `VectorDims`.
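Conceptually, `getDim(idx)` pairs up entry `idx` of the two underlying
lists. The standalone sketch below (hypothetical names, not the MLIR
API) mirrors the patch's `VectorDim::Indexer::operator[]`, including its
treatment of an empty scalability list as "every dim is fixed".

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for VectorDim: {minimum size, is scalable}.
using Dim = std::pair<int64_t, bool>;

// Mirrors the patch's Indexer: the scalable list may be empty, in which case
// every dim is treated as fixed; otherwise it must match `sizes` in length.
Dim getDim(const std::vector<int64_t> &sizes, const std::vector<bool> &scalable,
           size_t idx) {
  assert(idx < sizes.size() && "invalid dim index");
  assert((scalable.empty() || scalable.size() == sizes.size()) &&
         "expected scalable flags to be empty or match sizes in length");
  bool isScalable = !scalable.empty() && scalable[idx];
  return {sizes[idx], isScalable};
}
```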

`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
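A minimal standalone sketch of such a type, modelled on the patch's
`VectorDim` but not the MLIR class itself (the name `Dim` is
hypothetical), shows why a bare integer comparison is rejected: the only
constructor takes both a quantity and a kind, and there is no conversion
operator, so `dim == 1` fails to compile and callers must write
`dim == Dim::getFixed(1)`.

```c++
#include <cassert>
#include <cstdint>

// Hypothetical dim type in the spirit of VectorDim. No implicit conversion
// to or from int64_t exists, so comparisons must name the kind of quantity.
class Dim {
public:
  explicit constexpr Dim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}

  static constexpr Dim getFixed(int64_t q) { return Dim(q, false); }
  static constexpr Dim getScalable(int64_t q) { return Dim(q, true); }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return quantity; }

  // Two dims are equal only if both the size and the kind match.
  constexpr bool operator==(const Dim &rhs) const {
    return quantity == rhs.quantity && scalable == rhs.scalable;
  }
  constexpr bool operator!=(const Dim &rhs) const { return !(*this == rhs); }

private:
  int64_t quantity;
  bool scalable;
};

// Checked at compile time: a fixed 1 and a scalable [1] are different dims.
static_assert(Dim::getFixed(1) != Dim::getScalable(1), "kinds differ");
static_assert(Dim::getScalable(4).getMinSize() == 4, "min size is kept");
```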

`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator and a few common helper methods
(similar to those of ArrayRef).
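The non-owning, ArrayRef-like slicing can be sketched as a view that
holds pointers to the two lists plus an offset and a length, so that
`slice`, `dropFront` and `dropBack` never copy the underlying data. This
is a hypothetical standalone sketch (`DimsView` is not an MLIR class)
that assumes both lists have the same length; the patch's `VectorDims`
additionally accepts an empty scalability list. The referenced lists
must outlive the view.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical non-owning view over two parallel lists. It stores pointers to
// the lists plus an [offset, offset + length) window, so slicing is O(1).
class DimsView {
public:
  DimsView(const std::vector<int64_t> &sizes, const std::vector<bool> &scalable)
      : DimsView(&sizes, &scalable, 0, sizes.size()) {}

  size_t size() const { return length; }
  int64_t sizeAt(size_t i) const { return (*sizes)[offset + i]; }
  bool scalableAt(size_t i) const { return (*scalable)[offset + i]; }

  // Skip the first `n` dims and keep the next `m`, as ArrayRef::slice does.
  DimsView slice(size_t n, size_t m) const {
    assert(n + m <= length && "slice out of bounds");
    return DimsView(sizes, scalable, offset + n, m);
  }
  DimsView dropFront(size_t n = 1) const { return slice(n, length - n); }
  DimsView dropBack(size_t n = 1) const { return slice(0, length - n); }

private:
  DimsView(const std::vector<int64_t> *sizes, const std::vector<bool> *scalable,
           size_t offset, size_t length)
      : sizes(sizes), scalable(scalable), offset(offset), length(length) {}

  const std::vector<int64_t> *sizes;
  const std::vector<bool> *scalable;
  size_t offset;
  size_t length;
};
```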

There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.
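The `ArrayRef<VectorDim>` builder splits each dim back into the two
lists VectorType actually stores. A standalone sketch of that split,
using hypothetical `Dim` and `SplitShape` types in place of `VectorDim`
and the builder's local `SmallVector`s:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical {min size, scalable} dim, standing in for VectorDim.
struct Dim {
  int64_t minSize;
  bool scalable;
};

// Hypothetical result holding the two parallel lists that VectorType stores.
struct SplitShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalableDims;
};

// Mirrors the loop in the new ArrayRef<VectorDim> builder: each dim
// contributes its minimum size to one list and its scalability to the other.
SplitShape splitDims(const std::vector<Dim> &dims) {
  SplitShape result;
  result.sizes.reserve(dims.size());
  result.scalableDims.reserve(dims.size());
  for (const Dim &dim : dims) {
    result.sizes.push_back(dim.minSize);
    result.scalableDims.push_back(dim.scalable);
  }
  return result;
}
```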

With these changes the previous example becomes:

```c++
  auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
    return dim == VectorDim::getFixed(1);
  });
```

This is (to me) easier to read, and safer, as it is not possible to
forget to check the scalability of the dim: comparing with a plain `1`
would fail to build.
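`dropWhile` amounts to finding the first dim that does *not* match the
predicate and keeping everything from there on. The sketch below shows
the same idea over a `std::vector` of a hypothetical `Dim` type, using
`std::find_if_not` (the patch uses `llvm::find_if_not` and returns a
view rather than a copy). It reproduces the expectation in the unit
test: `1x1x1x[1]x[4]` trims to `[1]x[4]`.

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical dim type; equality compares both size and scalability.
struct Dim {
  int64_t minSize;
  bool scalable;
  bool operator==(const Dim &rhs) const {
    return minSize == rhs.minSize && scalable == rhs.scalable;
  }
};

// Returns the dims left after dropping the leading run that satisfies `pred`.
// Analogous to VectorDims::dropWhile, except that this copies the result.
template <class Pred>
std::vector<Dim> dropWhile(const std::vector<Dim> &dims, Pred pred) {
  auto firstKept = std::find_if_not(dims.begin(), dims.end(), pred);
  return std::vector<Dim>(firstKept, dims.end());
}

// Drops leading *fixed* unit dims only; a scalable [1] stops the scan.
std::vector<Dim> trimLeadingUnitDims(const std::vector<Dim> &dims) {
  const Dim fixedOne{1, false};
  return dropWhile(dims, [&](const Dim &d) { return d == fixedOne; });
}
```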
---
 mlir/include/mlir/IR/BuiltinTypes.h           | 192 ++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypes.td          |  23 +++
 .../ArmSME/sme-tile-type-conversion.mlir      |   7 +
 .../Conversion/ArmSMEToLLVM/CMakeLists.txt    |  20 ++
 .../ArmSMEToLLVM/TestTileTypeConversion.cpp   |  82 ++++++++
 mlir/unittests/IR/ShapedTypeTest.cpp          | 101 +++++++++
 6 files changed, 425 insertions(+)
 create mode 100644 mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
 create mode 100644 mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
 create mode 100644 mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace llvm {
 class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike the dimensions
+/// of other ShapedTypes, vector dimensions can be scalable: a scalable
+/// dimension has a known minimum size, which is multiplied by a constant that
+/// is only known at runtime.
+class VectorDim {
+public:
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable) {}
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable.
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// If this dimension is fixed, returns the number of elements; otherwise
+  /// asserts.
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed());
+    return quantity;
+  }
+
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Helper class for indexing into a list of sizes and a (possibly empty)
+  /// list of scalable dimensions, extracting VectorDim elements.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      assert((scalableDims.empty() ||
+              sizes.size() == scalableDims.size()) &&
+             "expected `scalableDims` to be empty or match `sizes` in "
+             "length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, VectorDim, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(index) {}
+
+    // Iterator boilerplate.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  /// Construct from iterator pair.
+  VectorDims(Iterator begin, Iterator end)
+      : VectorDims(VectorDims(begin.getIndexer())
+                       .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {}
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop off the first \p n dims, and keep the next \p m dims (as in
+  /// ArrayRef::slice).
+  VectorDims slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDims(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return a copy of *this without the leading dims matching \p predicate.
+  template <class PredicateT>
+  VectorDims dropWhile(PredicateT predicate) const {
+    return VectorDims(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims.
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+  /// Check for dim equality.
+  bool equals(VectorDims rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality.
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         scalableDims = isScalableVec;
       }
       return $_get(elementType.getContext(), shape, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+      SmallVector<int64_t> sizes;
+      SmallVector<bool> scalableDims;
+      for (VectorDim dim : shape) {
+        sizes.push_back(dim.getMinSize());
+        scalableDims.push_back(dim.isScalable());
+      }
+      return get(sizes, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+      return get(shape.getSizes(), elementType, shape.getScalableDims());
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
+    /// Returns the value of the specified dimension (including scalability).
+    VectorDim getDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid dim index for vector type");
+      return getDims()[idx];
+    }
+
+    /// Returns the dimensions of this vector type (including scalability).
+    VectorDims getDims() const {
+      return VectorDims(getShape(), getScalableDims());
+    }
+
     /// Returns true if the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
diff --git a/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir b/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
new file mode 100644
index 0000000000000..b4f702e3aae6d
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/sme-tile-type-conversion.mlir
@@ -0,0 +1,7 @@
+
+
+func.func @function_using_sme_size_2d_scalable_vec(%vec: vector<[4]x[4]xi32>) -> vector<[4]x[4]xi32> {
+  %c7 = arith.constant 7 : i32
+  %newVec = vector.insert %c7, %vec[0, 0] : i32 into vector<[4]x[4]xi32>
+  return %newVec : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..c713f085373ab
--- /dev/null
+++ b/mlir/test/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestArmSMEToLLVM
+  TestTileTypeConversion.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMEToLLVM
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRLLVMIRTransforms
+  MLIRPass
+  MLIRTestDialect
+  )
+
+target_include_directories(MLIRTestArmSMEToLLVM
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
+  )
diff --git a/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp b/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp
new file mode 100644
index 0000000000000..83febe2774d27
--- /dev/null
+++ b/mlir/test/lib/Conversion/ArmSMEToLLVM/TestTileTypeConversion.cpp
@@ -0,0 +1,82 @@
+//===- TestTileTypeConversion.cpp - Test LLVM Conversion of ArmSME tiles --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+class TestTypeProducerOpConverter
+    : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
+public:
+  using ConvertOpToLLVMPattern<
+      test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ZeroOp>(op, getVoidPtrType());
+    return success();
+  }
+};
+
+struct TestConvertCallOp
+    : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertCallOp)
+
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+  StringRef getArgument() const final { return "test-convert-call-op"; }
+  StringRef getDescription() const final {
+    return "Tests conversion of `func.call` to `llvm.call` in "
+           "presence of custom types";
+  }
+
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+
+    LowerToLLVMOptions options(m.getContext());
+
+    // Populate type conversions.
+    LLVMTypeConverter typeConverter(m.getContext(), options);
+    typeConverter.addConversion([&](test::TestType type) {
+      return LLVM::LLVMPointerType::get(m.getContext());
+    });
+    typeConverter.addConversion([&](test::SimpleAType type) {
+      return IntegerType::get(type.getContext(), 42);
+    });
+
+    // Populate patterns.
+    RewritePatternSet patterns(m.getContext());
+    populateFuncToLLVMConversionPatterns(typeConverter, patterns);
+    patterns.add<TestTypeProducerOpConverter>(typeConverter);
+
+    // Set target.
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addIllegalDialect<test::TestDialect>();
+    target.addIllegalDialect<func::FuncDialect>();
+
+    if (failed(applyPartialConversion(m, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+TEST(ShapedTypeTest, VectorDims) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+                              VectorDim::getFixed(8), VectorDim::getScalable(9),
+                              VectorDim::getFixed(1)};
+  VectorType vectorType = VectorType::get(f32, dims);
+
+  // Directly check values
+  {
+    auto dim0 = vectorType.getDim(0);
+    ASSERT_EQ(dim0.getMinSize(), 2);
+    ASSERT_TRUE(dim0.isFixed());
+
+    auto dim1 = vectorType.getDim(1);
+    ASSERT_EQ(dim1.getMinSize(), 4);
+    ASSERT_TRUE(dim1.isScalable());
+
+    auto dim2 = vectorType.getDim(2);
+    ASSERT_EQ(dim2.getMinSize(), 8);
+    ASSERT_TRUE(dim2.isFixed());
+
+    auto dim3 = vectorType.getDim(3);
+    ASSERT_EQ(dim3.getMinSize(), 9);
+    ASSERT_TRUE(dim3.isScalable());
+
+    auto dim4 = vectorType.getDim(4);
+    ASSERT_EQ(dim4.getMinSize(), 1);
+    ASSERT_TRUE(dim4.isFixed());
+  }
+
+  // Test indexing via getDim(idx)
+  {
+    for (unsigned i = 0; i < dims.size(); i++)
+      ASSERT_EQ(vectorType.getDim(i), dims[i]);
+  }
+
+  // Test using VectorDims::Iterator in for-each loop
+  {
+    unsigned i = 0;
+    for (VectorDim dim : vectorType.getDims())
+      ASSERT_EQ(dim, dims[i++]);
+    ASSERT_EQ(i, vectorType.getRank());
+  }
+
+  // Test using VectorDims::Iterator in LLVM iterator helper
+  {
+    for (auto [dim, expectedDim] :
+         llvm::zip_equal(vectorType.getDims(), dims)) {
+      ASSERT_EQ(dim, expectedDim);
+    }
+  }
+
+  // Test dropFront()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropFront();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+  }
+
+  // Test dropBack()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropBack();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i]);
+  }
+
+  // Test front()
+  { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+  // Test back()
+  { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+  // Test dropWhile.
+  {
+    SmallVector<VectorDim> dims{
+        VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+        VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+    VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+    ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+              unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+    // Drop leading unit dims.
+    auto withoutLeadingUnitDims =
+        vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+            [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+                                        VectorDim::getScalable(4)};
+    ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+  }
+}
+
 } // namespace


