[Mlir-commits] [mlir] [mlir][spirv] Handle failed conversions of struct elements (PR #70005)

Pierre van Houtryve llvmlistbot at llvm.org
Fri Oct 27 06:58:07 PDT 2023


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/70005

>From 4020ea59189025f5b681a381ca7444f9da414979 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 24 Oct 2023 08:48:48 +0200
Subject: [PATCH 1/4] [mlir][spirv] Handle failed conversions of struct
 elements

When an element type of a SPIR-V struct failed to convert, the resulting
LLVMStructType was created with null element types. This caused a crash
later in the LLVM dialect verifier.

Now, check that all struct elements were successfully converted before
passing them to the LLVMStructType constructor, and fail the struct
conversion otherwise.

See #59990
---
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 26 ++++++++++++-------
 .../spirv-types-to-llvm-invalid.mlir          |  7 +++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 60f34f413f587d4..5f752765f6d7f20 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -199,6 +199,16 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
 }
 
+static bool convertTypes(LLVMTypeConverter &converter, const spirv::StructType::ElementTypeRange &types, SmallVectorImpl<Type> &out) {
+  for(const auto &type: types) {
+    if(auto convertedType = converter.convertType(type))
+      out.push_back(convertedType);
+    else
+      return false;
+  }
+  return true;
+}
+
 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
 /// offset to LLVM struct. Otherwise, the conversion is not supported.
 static std::optional<Type>
@@ -207,21 +217,19 @@ convertStructTypeWithOffset(spirv::StructType type,
   if (type != VulkanLayoutUtils::decorateType(type))
     return std::nullopt;
 
-  auto elementsVector = llvm::to_vector<8>(
-      llvm::map_range(type.getElementTypes(), [&](Type elementType) {
-        return converter.convertType(elementType);
-      }));
+  SmallVector<Type> elementsVector;
+  if(!convertTypes(converter, type.getElementTypes(), elementsVector))
+    return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/false);
 }
 
 /// Converts SPIR-V struct with no offset to packed LLVM struct.
-static Type convertStructTypePacked(spirv::StructType type,
+static std::optional<Type> convertStructTypePacked(spirv::StructType type,
                                     LLVMTypeConverter &converter) {
-  auto elementsVector = llvm::to_vector<8>(
-      llvm::map_range(type.getElementTypes(), [&](Type elementType) {
-        return converter.convertType(elementType);
-      }));
+  SmallVector<Type> elementsVector;
+  if(!convertTypes(converter, type.getElementTypes(), elementsVector))
+    return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/true);
 }
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir
index 3965c47ec199fcb..438c90205abedc4 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir
@@ -7,6 +7,13 @@ spirv.func @array_with_unnatural_stride(%arg: !spirv.array<4 x f32, stride=8>) -
 
 // -----
 
+// expected-error @+1 {{failed to legalize operation 'spirv.func' that was explicitly marked illegal}}
+spirv.func @struct_array_with_unnatural_stride(%arg: !spirv.struct<(!spirv.array<4 x f32, stride=8>)>) -> () "None" {
+  spirv.Return
+}
+
+// -----
+
 // expected-error @+1 {{failed to legalize operation 'spirv.func' that was explicitly marked illegal}}
 spirv.func @struct_with_unnatural_offset(%arg: !spirv.struct<(i32[0], i32[8])>) -> () "None" {
   spirv.Return

>From e5f8eb54c85f6e34fdf0f8efb40fe03c10493fa1 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 24 Oct 2023 11:48:46 +0200
Subject: [PATCH 2/4] clang-format

---
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 5f752765f6d7f20..87acca4cb2812c9 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -199,9 +199,11 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
 }
 
-static bool convertTypes(LLVMTypeConverter &converter, const spirv::StructType::ElementTypeRange &types, SmallVectorImpl<Type> &out) {
-  for(const auto &type: types) {
-    if(auto convertedType = converter.convertType(type))
+static bool convertTypes(LLVMTypeConverter &converter,
+                         const spirv::StructType::ElementTypeRange &types,
+                         SmallVectorImpl<Type> &out) {
+  for (const auto &type : types) {
+    if (auto convertedType = converter.convertType(type))
       out.push_back(convertedType);
     else
       return false;
@@ -218,17 +220,17 @@ convertStructTypeWithOffset(spirv::StructType type,
     return std::nullopt;
 
   SmallVector<Type> elementsVector;
-  if(!convertTypes(converter, type.getElementTypes(), elementsVector))
+  if (!convertTypes(converter, type.getElementTypes(), elementsVector))
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/false);
 }
 
 /// Converts SPIR-V struct with no offset to packed LLVM struct.
-static std::optional<Type> convertStructTypePacked(spirv::StructType type,
-                                    LLVMTypeConverter &converter) {
+static std::optional<Type>
+convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) {
   SmallVector<Type> elementsVector;
-  if(!convertTypes(converter, type.getElementTypes(), elementsVector))
+  if (!convertTypes(converter, type.getElementTypes(), elementsVector))
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/true);

>From 443b5a953f252c90cea9ccdb0ff4ef433ca0f5dd Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Wed, 25 Oct 2023 09:03:49 +0200
Subject: [PATCH 3/4] Use TypeConverter::convertTypes

---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 22 +------------------
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 16 ++------------
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  6 ++---
 3 files changed, 6 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 07f2f158ecabb6f..4be2582f8fd68cc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -347,27 +347,7 @@ class StructType
 
   Type getElementType(unsigned) const;
 
-  /// Range class for element types.
-  class ElementTypeRange
-      : public ::llvm::detail::indexed_accessor_range_base<
-            ElementTypeRange, const Type *, Type, Type, Type> {
-  private:
-    using RangeBaseT::RangeBaseT;
-
-    /// See `llvm::detail::indexed_accessor_range_base` for details.
-    static const Type *offset_base(const Type *object, ptrdiff_t index) {
-      return object + index;
-    }
-    /// See `llvm::detail::indexed_accessor_range_base` for details.
-    static Type dereference_iterator(const Type *object, ptrdiff_t index) {
-      return object[index];
-    }
-
-    /// Allow base class access to `offset_base` and `dereference_iterator`.
-    friend RangeBaseT;
-  };
-
-  ElementTypeRange getElementTypes() const;
+  TypeRange getElementTypes() const;
 
   bool hasOffset() const;
 
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 87acca4cb2812c9..7c7f9a2f0506012 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -199,18 +199,6 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
 }
 
-static bool convertTypes(LLVMTypeConverter &converter,
-                         const spirv::StructType::ElementTypeRange &types,
-                         SmallVectorImpl<Type> &out) {
-  for (const auto &type : types) {
-    if (auto convertedType = converter.convertType(type))
-      out.push_back(convertedType);
-    else
-      return false;
-  }
-  return true;
-}
-
 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
 /// offset to LLVM struct. Otherwise, the conversion is not supported.
 static std::optional<Type>
@@ -220,7 +208,7 @@ convertStructTypeWithOffset(spirv::StructType type,
     return std::nullopt;
 
   SmallVector<Type> elementsVector;
-  if (!convertTypes(converter, type.getElementTypes(), elementsVector))
+  if (converter.convertTypes(type.getElementTypes(), elementsVector).failed())
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/false);
@@ -230,7 +218,7 @@ convertStructTypeWithOffset(spirv::StructType type,
 static std::optional<Type>
 convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) {
   SmallVector<Type> elementsVector;
-  if (!convertTypes(converter, type.getElementTypes(), elementsVector))
+  if (converter.convertTypes(type.getElementTypes(), elementsVector).failed())
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/true);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..f1bac6490837b93 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1146,9 +1146,9 @@ Type StructType::getElementType(unsigned index) const {
   return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
 }
 
-StructType::ElementTypeRange StructType::getElementTypes() const {
-  return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
-                          getNumElements());
+TypeRange StructType::getElementTypes() const {
+  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
+                   getNumElements());
 }
 
 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }

>From f97f38b3655ee62a5598493b17c5ff8d5a6b112d Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Fri, 27 Oct 2023 15:57:52 +0200
Subject: [PATCH 4/4] Use failed() instead of LogicalResult::failed()

---
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 7c7f9a2f0506012..d07f42baafa5e6a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -208,7 +208,7 @@ convertStructTypeWithOffset(spirv::StructType type,
     return std::nullopt;
 
   SmallVector<Type> elementsVector;
-  if (converter.convertTypes(type.getElementTypes(), elementsVector).failed())
+  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/false);
@@ -218,7 +218,7 @@ convertStructTypeWithOffset(spirv::StructType type,
 static std::optional<Type>
 convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) {
   SmallVector<Type> elementsVector;
-  if (converter.convertTypes(type.getElementTypes(), elementsVector).failed())
+  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
     return std::nullopt;
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/true);



More information about the Mlir-commits mailing list