[Mlir-commits] [mlir] 38eb55a - [mlir][llvm] Return failure from type converter for n-D scalable vectors (#65450)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 01:31:55 PDT 2023


Author: Cullen Rhodes
Date: 2023-09-11T09:31:48+01:00
New Revision: 38eb55a130e8056b56796bcd5c937c862939940c

URL: https://github.com/llvm/llvm-project/commit/38eb55a130e8056b56796bcd5c937c862939940c
DIFF: https://github.com/llvm/llvm-project/commit/38eb55a130e8056b56796bcd5c937c862939940c.diff

LOG: [mlir][llvm] Return failure from type converter for n-D scalable vectors (#65450)

This patch changes vector type conversion to return failure on n-D
scalable vector types instead of asserting.

This is an alternative approach to #65261 that aims to enable lowering
of Vector ops directly to ArmSME intrinsics where possible, and it seems
more consistent with other type conversions. At the moment it is trivial
to hit the assert, which could be read as implying that n-D scalable
vector types are a bug, when they are in fact valid types in the Vector
dialect.

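By returning failure, the conversion will generally fail more gracefully,
particularly in release builds or other builds where assertions are
disabled. In those builds the assert is compiled out, so nothing would
stop an n-D scalable vector from being converted.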

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d9..2a4327535c68750 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
   /// Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type) const;
+  FailureOr<Type> convertVectorType(VectorType type) const;
 
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..49e0513e629d951 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
-  addConversion([&](VectorType type) { return convertVectorType(type); });
+  addConversion([&](VectorType type) -> std::optional<Type> {
+    FailureOr<Type> llvmType = convertVectorType(type);
+    if (failed(llvmType))
+      return std::nullopt;
+    return llvmType;
+  });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
 ///  * 1-D `vector<axT>` remains as is while,
 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
-/// As LLVM does not support arrays of scalable vectors, it is assumed that
-/// scalable vectors are always 1-D. This condition could be relaxed once the
-/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) const {
+/// Returns failure for n-D scalable vector types as LLVM does not support
+/// arrays of scalable vectors.
+FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
-  assert(
-      (!type.isScalable() || (type.getRank() == 1)) &&
-      "expected 1-D scalable vector (n-D scalable vectors are not supported)");
+  if (type.isScalable() && (type.getRank() > 1))
+    return failure();
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);


        


More information about the Mlir-commits mailing list