[Mlir-commits] [mlir] [mlir][IR] Experiment: Allow ptr as vector element type (PR #125690)
Matthias Springer
llvmlistbot at llvm.org
Tue Feb 4 06:21:06 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/125690
(No pull request description was provided.)
>From 5b7422ad037d11f3c7a82f8d0b5007f4f2b3eee9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 4 Feb 2025 15:20:01 +0100
Subject: [PATCH] [mlir][IR] Experiment: Allow ptr as vector element type
---
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 1 +
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 4 +++-
mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td | 4 +++-
mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h | 1 +
mlir/include/mlir/IR/BuiltinTypes.h | 4 ++++
mlir/include/mlir/IR/BuiltinTypes.td | 7 ++++++-
mlir/include/mlir/IR/CommonTypeConstraints.td | 2 ++
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 3 ++-
mlir/test/IR/test-verifiers-type.mlir | 15 +++++++++++++++
9 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 8b380751c2f9d6..bca0feb45aab2a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 77c8035ce3d71a..7c5bea568c8399 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -257,7 +258,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
- "getIndexBitwidth", "areCompatible", "verifyEntries"]>]> {
+ "getIndexBitwidth", "areCompatible", "verifyEntries"]>,
+ PointerLike]> {
let summary = "LLVM pointer type";
let description = [{
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 14d72c3001d919..a849960fd72eb6 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -12,6 +12,7 @@
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
@@ -38,7 +39,8 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
- "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+ "areCompatible", "getIndexBitwidth", "verifyEntries"]>,
+ PointerLike
]> {
let summary = "pointer type";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
index 264a97c80722a2..4b86eaba9eed25 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d2..7f20878d02f677 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,6 +43,10 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
+/// Type trait indicating that the type is a pointer-like type.
+template <typename ConcreteType>
+class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
+
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9a..b066271e97d4ac 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
let cppNamespace = "::mlir";
}
+/// Type trait indicating that the type is a pointer-like type.
+def PointerLike : NativeTypeTrait<"PointerLike"> {
+ let cppNamespace = "::mlir";
+}
+
//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
@@ -1238,7 +1243,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
let cppFunctionName = "isValidVectorTypeElementType";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 82e335e30b6fa4..6f41daae621947 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
"::mlir::IndexType">,
BuildableType<"$_builder.getIndexType()">;
+def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
+
// Any signless integer type or index type.
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 453b206de294e4..c57d0bcf2e5559 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -876,7 +876,8 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
}
bool mlir::LLVM::isCompatibleVectorType(Type type) {
- if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
+ if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, LLVMPointerType>(
+ type))
return true;
if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 96d0005eb7a19d..eb88a259b64af0 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -7,3 +7,18 @@
// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
"test.type_producer"() : () -> !test.type_verification<f16>
+
+// -----
+
+// CHECK: "test.type_producer"() : () -> vector<!ptr.ptr<5 : i64>>
+"test.type_producer"() : () -> vector<!ptr.ptr<5>>
+
+// -----
+
+// CHECK: "test.type_producer"() : () -> vector<!llvm.ptr<1>>
+"test.type_producer"() : () -> vector<!llvm.ptr<1>>
+
+// -----
+
+// expected-error @below{{failed to verify 'elementType': integer or index or floating-point or pointer-like}}
+"test.type_producer"() : () -> vector<memref<2xf32>>
More information about the Mlir-commits mailing list