[Mlir-commits] [mlir] [mlir][IR] Experiment: Allow ptr as vector element type (PR #125690)

Matthias Springer llvmlistbot at llvm.org
Tue Feb 4 06:21:06 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/125690

(No description provided.)

>From 5b7422ad037d11f3c7a82f8d0b5007f4f2b3eee9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 4 Feb 2025 15:20:01 +0100
Subject: [PATCH] [mlir][IR] Experiment: Allow ptr as vector element type

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h   |  1 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td  |  4 +++-
 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td |  4 +++-
 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h    |  1 +
 mlir/include/mlir/IR/BuiltinTypes.h            |  4 ++++
 mlir/include/mlir/IR/BuiltinTypes.td           |  7 ++++++-
 mlir/include/mlir/IR/CommonTypeConstraints.td  |  2 ++
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp       |  3 ++-
 mlir/test/IR/test-verifiers-type.mlir          | 15 +++++++++++++++
 9 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 8b380751c2f9d6..bca0feb45aab2a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 77c8035ce3d71a..7c5bea568c8399 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 
@@ -257,7 +258,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
 
 def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
-      "getIndexBitwidth", "areCompatible", "verifyEntries"]>]> {
+      "getIndexBitwidth", "areCompatible", "verifyEntries"]>,
+    PointerLike]> {
   let summary = "LLVM pointer type";
   let description = [{
     The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 14d72c3001d919..a849960fd72eb6 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -12,6 +12,7 @@
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
 include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
@@ -38,7 +39,8 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
 def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
-      "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+      "areCompatible", "getIndexBitwidth", "verifyEntries"]>,
+    PointerLike
   ]> {
   let summary = "pointer type";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
index 264a97c80722a2..4b86eaba9eed25 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
 #define MLIR_DIALECT_PTR_IR_PTRTYPES_H
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d2..7f20878d02f677 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,6 +43,10 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
+/// Type trait marking a type as pointer-like (valid vector element type).
+template <typename ConcreteType>
+class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9a..b066271e97d4ac 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
   let cppNamespace = "::mlir";
 }
 
+/// Type trait marking a type as pointer-like (valid vector element type).
+def PointerLike : NativeTypeTrait<"PointerLike"> {
+  let cppNamespace = "::mlir";
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexType
 //===----------------------------------------------------------------------===//
@@ -1238,7 +1243,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
   let cppFunctionName = "isValidVectorTypeElementType";
 }
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 82e335e30b6fa4..6f41daae621947 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
                  "::mlir::IndexType">,
             BuildableType<"$_builder.getIndexType()">;
 
+def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
+
 // Any signless integer type or index type.
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 453b206de294e4..c57d0bcf2e5559 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -876,7 +876,8 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
 }
 
 bool mlir::LLVM::isCompatibleVectorType(Type type) {
-  if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
+  if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, LLVMPointerType>(
+          type))
     return true;
 
   if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 96d0005eb7a19d..eb88a259b64af0 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -7,3 +7,18 @@
 
 // expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
 "test.type_producer"() : () -> !test.type_verification<f16>
+
+// -----
+
+// CHECK: "test.type_producer"() : () -> vector<!ptr.ptr<5 : i64>>
+"test.type_producer"() : () -> vector<!ptr.ptr<5>>
+
+// -----
+
+// CHECK: "test.type_producer"() : () -> vector<!llvm.ptr<1>>
+"test.type_producer"() : () -> vector<!llvm.ptr<1>>
+
+// -----
+
+// expected-error @below{{failed to verify 'elementType': integer or index or floating-point or pointer-like}}
+"test.type_producer"() : () -> vector<memref<2xf32>>



More information about the Mlir-commits mailing list