[Mlir-commits] [mlir] [mlir][llvm] Return failure from type converter for n-D scalable vectors (PR #65450)

Cullen Rhodes llvmlistbot at llvm.org
Wed Sep 6 00:50:43 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/65450:

This patch changes vector type conversion to return failure on n-D scalable vector types instead of asserting.

This is an alternative approach to #65261. It aims to enable lowering of Vector ops directly to ArmSME intrinsics where possible, and it seems more consistent with how other type conversions behave. At the moment the assert is trivial to hit, and hitting it suggests that n-D scalable vector types are a bug, when they are in fact valid types in the Vector dialect.

Returning failure generally lets the conversion fail more gracefully, particularly in release builds or other builds where assertions are disabled and the assert offers no protection at all.
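The rule this patch enforces can be illustrated without MLIR. The following is a standalone C++ sketch under stated assumptions: `VectorShape` and `convertVectorShape` are hypothetical names invented for illustration, not MLIR APIs. The result is a string in the notation used by the `convertVectorType` doc comment rather than an `mlir::Type`. The sketch shows which shapes still convert and which now yield a failure instead of tripping an assertion.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (not an MLIR API):
// just the shape, per-dimension scalability flags and element type.
struct VectorShape {
  std::vector<int64_t> dims;       // e.g. {4, 8} for vector<4x8xf32>
  std::vector<bool> scalableDims;  // true marks a scalable dim, e.g. [8]
  std::string elementType;         // e.g. "f32"
};

// Mirrors the rule convertVectorType follows after this patch, producing a
// string in the notation of its doc comment instead of an mlir::Type:
//   * 1-D vector<axT> stays a 1-D vector,
//   * n-D fixed-length vectors become nested arrays of 1-D vectors,
//   * scalable vectors of rank > 1 yield std::nullopt (failure), because
//     LLVM does not support arrays of scalable vectors.
std::optional<std::string> convertVectorShape(const VectorShape &v) {
  assert(v.dims.size() == v.scalableDims.size() && "one flag per dim");
  if (v.dims.empty())
    return std::nullopt;  // 0-D vectors are not modeled in this sketch.

  // Like VectorType::isScalable(): true if *any* dimension is scalable.
  bool isScalable = false;
  for (bool s : v.scalableDims)
    isScalable = isScalable || s;
  if (isScalable && v.dims.size() > 1)
    return std::nullopt;  // Previously an assertion; now a failure.

  // The trailing dimension becomes the LLVM vector type.
  std::string size = std::to_string(v.dims.back());
  if (v.scalableDims.back())
    size = "[" + size + "]";
  std::string result = "vector<" + size + "x" + v.elementType + ">";

  // Wrap the leading dimensions in arrays, innermost first.
  for (int i = static_cast<int>(v.dims.size()) - 2; i >= 0; --i)
    result = "array<" + std::to_string(v.dims[i]) + " x " + result + ">";
  if (v.dims.size() > 1)
    result = "!llvm." + result;
  return result;
}
```

With this sketch, `vector<2x4x8xf32>` maps to `!llvm.array<2 x array<4 x vector<8xf32>>>` and `vector<[4]xf32>` stays `vector<[4]xf32>`. Both `vector<4x[8]xf32>` and `vector<[4]x8xf32>` return `std::nullopt`, because the check asks whether any dimension is scalable, not only the trailing one.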

From 885666a9c3ef3cad10ae3b026672afbc2894db33 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 6 Sep 2023 07:20:59 +0000
Subject: [PATCH] [mlir][llvm] Return failure from type converter for n-D
 scalable vectors

This patch changes vector type conversion to return failure on n-D
scalable vector types instead of asserting.

This is an alternative approach to #65261 that aims to enable lowering
of Vector ops directly to ArmSME intrinsics where possible, and seems
more consistent with other type conversions. It's trivial to hit the
assert at the moment and it could be interpreted as n-D scalable vector
types being a bug, when they're valid types in the Vector dialect.

By returning failure it will generally fail more gracefully,
particularly for release builds or other builds where assertions are
disabled.
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  2 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 19 +++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d9..2a4327535c68750 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
   /// Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type) const;
+  FailureOr<Type> convertVectorType(VectorType type) const;
 
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..49e0513e629d951 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
-  addConversion([&](VectorType type) { return convertVectorType(type); });
+  addConversion([&](VectorType type) -> std::optional<Type> {
+    FailureOr<Type> llvmType = convertVectorType(type);
+    if (failed(llvmType))
+      return std::nullopt;
+    return llvmType;
+  });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
 ///  * 1-D `vector<axT>` remains as is while,
 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
-/// As LLVM does not support arrays of scalable vectors, it is assumed that
-/// scalable vectors are always 1-D. This condition could be relaxed once the
-/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) const {
+/// Returns failure for n-D scalable vector types as LLVM does not support
+/// arrays of scalable vectors.
+FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
-  assert(
-      (!type.isScalable() || (type.getRank() == 1)) &&
-      "expected 1-D scalable vector (n-D scalable vectors are not supported)");
+  if (type.isScalable() && (type.getRank() > 1))
+    return failure();
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
