[Mlir-commits] [mlir] 9bad924 - [mlir][LLVMIR] Apply SubElementTypeInterface on suitable types

Min-Yih Hsu llvmlistbot at llvm.org
Wed Jun 29 13:58:40 PDT 2022


Author: Min-Yih Hsu
Date: 2022-06-29T13:58:02-07:00
New Revision: 9bad9248ed3038eaa0cd0aeebb19566233e0f3e6

URL: https://github.com/llvm/llvm-project/commit/9bad9248ed3038eaa0cd0aeebb19566233e0f3e6
DIFF: https://github.com/llvm/llvm-project/commit/9bad9248ed3038eaa0cd0aeebb19566233e0f3e6.diff

LOG: [mlir][LLVMIR] Apply SubElementTypeInterface on suitable types

This feature is covered by a unit test, since few places in the codebase
use SubElementTypeInterface.

Differential Revision: https://reviews.llvm.org/D127539

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 50537f6c9abed..e415061768fe4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 
+#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
@@ -73,7 +74,8 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 /// type.
 class LLVMArrayType
     : public Type::TypeBase<LLVMArrayType, Type, detail::LLVMTypeAndSizeStorage,
-                            DataLayoutTypeInterface::Trait> {
+                            DataLayoutTypeInterface::Trait,
+                            SubElementTypeInterface::Trait> {
 public:
   /// Inherit base constructors.
   using Base::Base;
@@ -111,6 +113,9 @@ class LLVMArrayType
 
   unsigned getPreferredAlignment(const DataLayout &dataLayout,
                                  DataLayoutEntryListRef params) const;
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -120,9 +125,9 @@ class LLVMArrayType
 /// LLVM dialect function type. It consists of a single return type (unlike MLIR
 /// which can have multiple), a list of parameter types and can optionally be
 /// variadic.
-class LLVMFunctionType
-    : public Type::TypeBase<LLVMFunctionType, Type,
-                            detail::LLVMFunctionTypeStorage> {
+class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
+                                               detail::LLVMFunctionTypeStorage,
+                                               SubElementTypeInterface::Trait> {
 public:
   /// Inherit base constructors.
   using Base::Base;
@@ -150,11 +155,11 @@ class LLVMFunctionType
   LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
 
   /// Returns the result type of the function.
-  Type getReturnType();
+  Type getReturnType() const;
 
   /// Returns the result type of the function as an ArrayRef, enabling better
   /// integration with generic MLIR utilities.
-  ArrayRef<Type> getReturnTypes();
+  ArrayRef<Type> getReturnTypes() const;
 
   /// Returns the number of arguments to the function.
   unsigned getNumParams();
@@ -163,12 +168,15 @@ class LLVMFunctionType
   Type getParamType(unsigned i);
 
   /// Returns a list of argument types of the function.
-  ArrayRef<Type> getParams();
+  ArrayRef<Type> getParams() const;
   ArrayRef<Type> params() { return getParams(); }
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               Type result, ArrayRef<Type> arguments, bool);
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -179,9 +187,10 @@ class LLVMFunctionType
 /// object in memory. Pointers may be opaque or parameterized by the element
 /// type. Both opaque and non-opaque pointers are additionally parameterized by
 /// the address space.
-class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
-                                              detail::LLVMPointerTypeStorage,
-                                              DataLayoutTypeInterface::Trait> {
+class LLVMPointerType
+    : public Type::TypeBase<
+          LLVMPointerType, Type, detail::LLVMPointerTypeStorage,
+          DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> {
 public:
   /// Inherit base constructors.
   using Base::Base;
@@ -232,6 +241,9 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
                      DataLayoutEntryListRef newLayout) const;
   LogicalResult verifyEntries(DataLayoutEntryListRef entries,
                               Location loc) const;
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -265,6 +277,7 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
 class LLVMStructType
     : public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
                             DataLayoutTypeInterface::Trait,
+                            SubElementTypeInterface::Trait,
                             TypeTrait::IsMutable> {
 public:
   /// Inherit base constructors.
@@ -359,6 +372,9 @@ class LLVMStructType
 
   LogicalResult verifyEntries(DataLayoutEntryListRef entries,
                               Location loc) const;
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -369,7 +385,8 @@ class LLVMStructType
 /// length that can be processed as one.
 class LLVMFixedVectorType
     : public Type::TypeBase<LLVMFixedVectorType, Type,
-                            detail::LLVMTypeAndSizeStorage> {
+                            detail::LLVMTypeAndSizeStorage,
+                            SubElementTypeInterface::Trait> {
 public:
   /// Inherit base constructor.
   using Base::Base;
@@ -388,7 +405,7 @@ class LLVMFixedVectorType
   static bool isValidElementType(Type type);
 
   /// Returns the element type of the vector.
-  Type getElementType();
+  Type getElementType() const;
 
   /// Returns the number of elements in the fixed vector.
   unsigned getNumElements();
@@ -396,6 +413,9 @@ class LLVMFixedVectorType
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               Type elementType, unsigned numElements);
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -407,7 +427,8 @@ class LLVMFixedVectorType
 /// elements can be processed as one in SIMD context.
 class LLVMScalableVectorType
     : public Type::TypeBase<LLVMScalableVectorType, Type,
-                            detail::LLVMTypeAndSizeStorage> {
+                            detail::LLVMTypeAndSizeStorage,
+                            SubElementTypeInterface::Trait> {
 public:
   /// Inherit base constructor.
   using Base::Base;
@@ -424,7 +445,7 @@ class LLVMScalableVectorType
   static bool isValidElementType(Type type);
 
   /// Returns the element type of the vector.
-  Type getElementType();
+  Type getElementType() const;
 
   /// Returns the scaling factor of the number of elements in the vector. The
   /// vector contains at least the resulting number of elements, or any non-zero
@@ -434,6 +455,9 @@ class LLVMScalableVectorType
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               Type elementType, unsigned minNumElements);
+
+  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+                                function_ref<void(Type)> walkTypesFn) const;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b02d53a2efae7..49d2d8d24963b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -86,6 +86,12 @@ LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
   return dataLayout.getTypePreferredAlignment(getElementType());
 }
 
+void LLVMArrayType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType()); // Reports the element type and nothing else.
+}
+
 //===----------------------------------------------------------------------===//
 // Function type.
 //===----------------------------------------------------------------------===//
@@ -119,8 +125,10 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
   return get(results[0], llvm::to_vector(inputs), isVarArg());
 }
 
-Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
-ArrayRef<Type> LLVMFunctionType::getReturnTypes() {
+Type LLVMFunctionType::getReturnType() const {
+  return getImpl()->getReturnType();
+}
+ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
   return getImpl()->getReturnType();
 }
 
@@ -134,7 +142,7 @@ Type LLVMFunctionType::getParamType(unsigned i) {
 
 bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
 
-ArrayRef<Type> LLVMFunctionType::getParams() {
+ArrayRef<Type> LLVMFunctionType::getParams() const {
   return getImpl()->getArgumentTypes();
 }
 
@@ -151,6 +159,13 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+void LLVMFunctionType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
+    walkTypesFn(type); // Return type(s) first, then the parameter types.
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer type.
 //===----------------------------------------------------------------------===//
@@ -353,6 +368,12 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
   return success();
 }
 
+void LLVMPointerType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType()); // NOTE(review): may be null if opaque; confirm.
+}
+
 //===----------------------------------------------------------------------===//
 // Struct type.
 //===----------------------------------------------------------------------===//
@@ -589,6 +610,13 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
   return mlir::success();
 }
 
+void LLVMStructType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  for (Type type : getBody())
+    walkTypesFn(type); // Reports each type returned by getBody(), in order.
+}
+
 //===----------------------------------------------------------------------===//
 // Vector types.
 //===----------------------------------------------------------------------===//
@@ -621,7 +649,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           numElements);
 }
 
-Type LLVMFixedVectorType::getElementType() {
+Type LLVMFixedVectorType::getElementType() const {
   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
 }
 
@@ -640,6 +668,12 @@ LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
       emitError, elementType, numElements);
 }
 
+void LLVMFixedVectorType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType()); // Reports the element type and nothing else.
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMScalableVectorType.
 //===----------------------------------------------------------------------===//
@@ -658,7 +692,7 @@ LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           minNumElements);
 }
 
-Type LLVMScalableVectorType::getElementType() {
+Type LLVMScalableVectorType::getElementType() const {
   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
 }
 
@@ -680,6 +714,12 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
       emitError, elementType, numElements);
 }
 
+void LLVMScalableVectorType::walkImmediateSubElements(
+    function_ref<void(Attribute)> walkAttrsFn, // Never invoked by this walker.
+    function_ref<void(Type)> walkTypesFn) const {
+  walkTypesFn(getElementType()); // Reports the element type and nothing else.
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index 9c0ea4f14d766..75c6fd004e3d4 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -18,3 +18,46 @@ TEST_F(LLVMIRTest, IsStructTypeMutable) {
   ASSERT_TRUE(bool(structTy));
   ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
 }
+
+TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
+  auto fooStructTy = LLVMStructType::getIdentified(&context, "foo");
+  ASSERT_TRUE(bool(fooStructTy));
+  auto barStructTy = LLVMStructType::getIdentified(&context, "bar");
+  ASSERT_TRUE(bool(barStructTy));
+
+  // Create two structs that reference each other through pointers.
+  Type fooBody[] = {LLVMPointerType::get(barStructTy)};
+  ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false)));
+  Type barBody[] = {LLVMPointerType::get(fooStructTy)};
+  ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false)));
+
+  auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
+  ASSERT_TRUE(bool(subElementInterface));
+  // Check that walkSubElements does not loop forever on the cyclic reference.
+  SmallVector<Type, 4> subElementTypes;
+  subElementInterface.walkSubElements(
+      [](Attribute attr) {},
+      [&](Type type) { subElementTypes.push_back(type); });
+  // LLVMPointerType is not recorded (because it's immutable), so
+  // !llvm.ptr<struct<"bar",...>> is visited twice: five entries in total.
+  ASSERT_EQ(subElementTypes.size(), 5U);
+
+  // [0] !llvm.ptr<struct<"bar",...>>
+  ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
+
+  // [1] !llvm.struct<"foo",...>
+  auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
+  ASSERT_TRUE(bool(structType));
+  ASSERT_TRUE(structType.getName().equals("foo"));
+
+  // [2] !llvm.ptr<struct<"foo",...>>
+  ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
+
+  // [3] !llvm.struct<"bar",...>
+  structType = subElementTypes[3].dyn_cast<LLVMStructType>();
+  ASSERT_TRUE(bool(structType));
+  ASSERT_TRUE(structType.getName().equals("bar"));
+
+  // [4] !llvm.ptr<struct<"bar",...>> (the second visit of this pointer type)
+  ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
+}


        


More information about the Mlir-commits mailing list