[Mlir-commits] [mlir] [mlir][SPIRV] Add decorateType method for MatrixType (PR #112018)

MingZhu Yan llvmlistbot at llvm.org
Wed Oct 30 17:55:57 PDT 2024


https://github.com/trdthg updated https://github.com/llvm/llvm-project/pull/112018

>From cd90206725525140c4d460e2cdb053cdf71c21f1 Mon Sep 17 00:00:00 2001
From: MingZhu Yan <yanmingzhu at iscas.ac.cn>
Date: Sat, 12 Oct 2024 11:16:15 +0800
Subject: [PATCH] [mlir][SPIRV] Add decorateType method for MatrixType

This PR adds a VulkanLayoutUtils::decorateType overload for
spirv::MatrixType, so that a `spirv.matrix` member with an offset
inside a `spirv.struct` is laid out correctly.

Signed-off-by: MingZhu Yan <yanmingzhu at iscas.ac.cn>
---
 .../mlir/Dialect/SPIRV/Utils/LayoutUtils.h    |  3 +++
 mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp  | 21 +++++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/types.mlir         |  5 +++++
 3 files changed, 29 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2da..72683d50d74117 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
 class ArrayType;
 class RuntimeArrayType;
 class StructType;
+class MatrixType;
 } // namespace spirv
 
 /// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
   static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
   static Type decorateType(spirv::ArrayType arrayType, Size &size,
                            Size &alignment);
+  static Type decorateType(spirv::MatrixType matrixType, Size &size,
+                           Size &alignment);
   static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
   static spirv::StructType decorateType(spirv::StructType structType,
                                         Size &size, Size &alignment);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc374452..48124d153d58cd 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
     return decorateType(arrayType, size, alignment);
   if (auto vectorType = dyn_cast<VectorType>(type))
     return decorateType(vectorType, size, alignment);
+  if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+    return decorateType(matrixType, size, alignment);
   if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
     size = std::numeric_limits<Size>().max();
     return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
   return spirv::ArrayType::get(memberType, numElements, elementSize);
 }
 
+/// Lays out a spirv.matrix like an array of its scalar elements, following
+/// "A matrix type inherits scalar alignment from the equivalent array
+/// declaration." (Vulkan spec). NOTE(review): that is the scalar-layout rule;
+/// standard layout aligns a column-major matrix to its column type; confirm.
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+                                     VulkanLayoutUtils::Size &size,
+                                     VulkanLayoutUtils::Size &alignment) {
+  // getColumnType() yields a Type (e.g. vector<3xf32>), not an integer.
+  Type columnType = matrixType.getColumnType();
+  const unsigned numColumns = matrixType.getNumColumns();
+  Size elementSize = 0;
+  Size elementAlignment = 1;
+  decorateType(matrixType.getElementType(), elementSize, elementAlignment);
+  size = elementSize * matrixType.getNumElements();
+  alignment = elementAlignment;
+  // The matrix type itself is returned unchanged (rebuilt from column/count).
+  return spirv::MatrixType::get(columnType, numColumns);
+}
+
 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
                                      VulkanLayoutUtils::Size &alignment) {
   auto elementType = arrayType.getElementType();
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd9..b63a08d96e6af9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
 
 // -----
 
+// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+
+// -----
+
 // expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
 func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
 



More information about the Mlir-commits mailing list