[Mlir-commits] [mlir] 9bad924 - [mlir][LLVMIR] Apply SubElementTypeInterface on suitable types
Min-Yih Hsu
llvmlistbot at llvm.org
Wed Jun 29 13:58:40 PDT 2022
Author: Min-Yih Hsu
Date: 2022-06-29T13:58:02-07:00
New Revision: 9bad9248ed3038eaa0cd0aeebb19566233e0f3e6
URL: https://github.com/llvm/llvm-project/commit/9bad9248ed3038eaa0cd0aeebb19566233e0f3e6
DIFF: https://github.com/llvm/llvm-project/commit/9bad9248ed3038eaa0cd0aeebb19566233e0f3e6.diff
LOG: [mlir][LLVMIR] Apply SubElementTypeInterface on suitable types
This feature is tested by a unit test, since not many places in the codebase
use SubElementTypeInterface.
Differential Revision: https://reviews.llvm.org/D127539
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 50537f6c9abed..e415061768fe4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -73,7 +74,8 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
/// type.
class LLVMArrayType
: public Type::TypeBase<LLVMArrayType, Type, detail::LLVMTypeAndSizeStorage,
- DataLayoutTypeInterface::Trait> {
+ DataLayoutTypeInterface::Trait,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
@@ -111,6 +113,9 @@ class LLVMArrayType
unsigned getPreferredAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
@@ -120,9 +125,9 @@ class LLVMArrayType
/// LLVM dialect function type. It consists of a single return type (unlike MLIR
/// which can have multiple), a list of parameter types and can optionally be
/// variadic.
-class LLVMFunctionType
- : public Type::TypeBase<LLVMFunctionType, Type,
- detail::LLVMFunctionTypeStorage> {
+class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
+ detail::LLVMFunctionTypeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
@@ -150,11 +155,11 @@ class LLVMFunctionType
LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
/// Returns the result type of the function.
- Type getReturnType();
+ Type getReturnType() const;
/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
- ArrayRef<Type> getReturnTypes();
+ ArrayRef<Type> getReturnTypes() const;
/// Returns the number of arguments to the function.
unsigned getNumParams();
@@ -163,12 +168,15 @@ class LLVMFunctionType
Type getParamType(unsigned i);
/// Returns a list of argument types of the function.
- ArrayRef<Type> getParams();
+ ArrayRef<Type> getParams() const;
ArrayRef<Type> params() { return getParams(); }
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type result, ArrayRef<Type> arguments, bool);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
@@ -179,9 +187,10 @@ class LLVMFunctionType
/// object in memory. Pointers may be opaque or parameterized by the element
/// type. Both opaque and non-opaque pointers are additionally parameterized by
/// the address space.
-class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
- detail::LLVMPointerTypeStorage,
- DataLayoutTypeInterface::Trait> {
+class LLVMPointerType
+ : public Type::TypeBase<
+ LLVMPointerType, Type, detail::LLVMPointerTypeStorage,
+ DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
@@ -232,6 +241,9 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
DataLayoutEntryListRef newLayout) const;
LogicalResult verifyEntries(DataLayoutEntryListRef entries,
Location loc) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
@@ -265,6 +277,7 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait,
+ SubElementTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
@@ -359,6 +372,9 @@ class LLVMStructType
LogicalResult verifyEntries(DataLayoutEntryListRef entries,
Location loc) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
@@ -369,7 +385,8 @@ class LLVMStructType
/// length that can be processed as one.
class LLVMFixedVectorType
: public Type::TypeBase<LLVMFixedVectorType, Type,
- detail::LLVMTypeAndSizeStorage> {
+ detail::LLVMTypeAndSizeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructor.
using Base::Base;
@@ -388,7 +405,7 @@ class LLVMFixedVectorType
static bool isValidElementType(Type type);
/// Returns the element type of the vector.
- Type getElementType();
+ Type getElementType() const;
/// Returns the number of elements in the fixed vector.
unsigned getNumElements();
@@ -396,6 +413,9 @@ class LLVMFixedVectorType
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
@@ -407,7 +427,8 @@ class LLVMFixedVectorType
/// elements can be processed as one in SIMD context.
class LLVMScalableVectorType
: public Type::TypeBase<LLVMScalableVectorType, Type,
- detail::LLVMTypeAndSizeStorage> {
+ detail::LLVMTypeAndSizeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructor.
using Base::Base;
@@ -424,7 +445,7 @@ class LLVMScalableVectorType
static bool isValidElementType(Type type);
/// Returns the element type of the vector.
- Type getElementType();
+ Type getElementType() const;
/// Returns the scaling factor of the number of elements in the vector. The
/// vector contains at least the resulting number of elements, or any non-zero
@@ -434,6 +455,9 @@ class LLVMScalableVectorType
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned minNumElements);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b02d53a2efae7..49d2d8d24963b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -86,6 +86,12 @@ LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
return dataLayout.getTypePreferredAlignment(getElementType());
}
+void LLVMArrayType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//
@@ -119,8 +125,10 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
return get(results[0], llvm::to_vector(inputs), isVarArg());
}
-Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
-ArrayRef<Type> LLVMFunctionType::getReturnTypes() {
+Type LLVMFunctionType::getReturnType() const {
+ return getImpl()->getReturnType();
+}
+ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
return getImpl()->getReturnType();
}
@@ -134,7 +142,7 @@ Type LLVMFunctionType::getParamType(unsigned i) {
bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
-ArrayRef<Type> LLVMFunctionType::getParams() {
+ArrayRef<Type> LLVMFunctionType::getParams() const {
return getImpl()->getArgumentTypes();
}
@@ -151,6 +159,13 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+void LLVMFunctionType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
+ walkTypesFn(type);
+}
+
//===----------------------------------------------------------------------===//
// Pointer type.
//===----------------------------------------------------------------------===//
@@ -353,6 +368,12 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
return success();
}
+void LLVMPointerType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Struct type.
//===----------------------------------------------------------------------===//
@@ -589,6 +610,13 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
return mlir::success();
}
+void LLVMStructType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ for (Type type : getBody())
+ walkTypesFn(type);
+}
+
//===----------------------------------------------------------------------===//
// Vector types.
//===----------------------------------------------------------------------===//
@@ -621,7 +649,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
numElements);
}
-Type LLVMFixedVectorType::getElementType() {
+Type LLVMFixedVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
@@ -640,6 +668,12 @@ LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
emitError, elementType, numElements);
}
+void LLVMFixedVectorType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// LLVMScalableVectorType.
//===----------------------------------------------------------------------===//
@@ -658,7 +692,7 @@ LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
minNumElements);
}
-Type LLVMScalableVectorType::getElementType() {
+Type LLVMScalableVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
@@ -680,6 +714,12 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
emitError, elementType, numElements);
}
+void LLVMScalableVectorType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index 9c0ea4f14d766..75c6fd004e3d4 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -18,3 +18,46 @@ TEST_F(LLVMIRTest, IsStructTypeMutable) {
ASSERT_TRUE(bool(structTy));
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
}
+
+TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
+ auto fooStructTy = LLVMStructType::getIdentified(&context, "foo");
+ ASSERT_TRUE(bool(fooStructTy));
+ auto barStructTy = LLVMStructType::getIdentified(&context, "bar");
+ ASSERT_TRUE(bool(barStructTy));
+
+ // Create two structs that reference each other.
+ Type fooBody[] = {LLVMPointerType::get(barStructTy)};
+ ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false)));
+ Type barBody[] = {LLVMPointerType::get(fooStructTy)};
+ ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false)));
+
+ auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
+ ASSERT_TRUE(bool(subElementInterface));
+ // Test if walkSubElements goes into infinite loops.
+ SmallVector<Type, 4> subElementTypes;
+ subElementInterface.walkSubElements(
+ [](Attribute attr) {},
+ [&](Type type) { subElementTypes.push_back(type); });
+ // Only mutable types are recorded as visited, so the immutable
+ // !llvm.ptr<struct<"bar",...>> is visited twice.
+ ASSERT_EQ(subElementTypes.size(), 5U);
+
+ // !llvm.ptr<struct<"bar",...>>
+ ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
+
+ // !llvm.struct<"foo",...>
+ auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
+ ASSERT_TRUE(bool(structType));
+ ASSERT_TRUE(structType.getName().equals("foo"));
+
+ // !llvm.ptr<struct<"foo",...>>
+ ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
+
+ // !llvm.struct<"bar",...>
+ structType = subElementTypes[3].dyn_cast<LLVMStructType>();
+ ASSERT_TRUE(bool(structType));
+ ASSERT_TRUE(structType.getName().equals("bar"));
+
+ // !llvm.ptr<struct<"bar",...>>
+ ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
+}
More information about the Mlir-commits
mailing list