[Mlir-commits] [mlir] c4cfc95 - [mlir][SPIRV] Add decorateType method for MatrixType (#112018)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 25 08:21:05 PDT 2025


Author: Mingzhu Yan
Date: 2025-05-25T08:21:01-07:00
New Revision: c4cfc95d76f250943e2a8c589afb1658ff1d1524

URL: https://github.com/llvm/llvm-project/commit/c4cfc95d76f250943e2a8c589afb1658ff1d1524
DIFF: https://github.com/llvm/llvm-project/commit/c4cfc95d76f250943e2a8c589afb1658ff1d1524.diff

LOG: [mlir][SPIRV] Add decorateType method for MatrixType (#112018)

Fixes #108161 

This PR adds a `decorateType` overload for `spirv::MatrixType`, so that a
`spirv.matrix` member with an explicit offset inside a `spirv.struct` is handled correctly.

Signed-off-by: MingZhu Yan <yanmingzhu at iscas.ac.cn>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
    mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
    mlir/test/Dialect/SPIRV/IR/types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2d..72683d50d7411 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
 class ArrayType;
 class RuntimeArrayType;
 class StructType;
+class MatrixType;
 } // namespace spirv
 
 /// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
   static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
   static Type decorateType(spirv::ArrayType arrayType, Size &size,
                            Size &alignment);
+  static Type decorateType(spirv::MatrixType matrixType, Size &size,
+                           Size &alignment);
   static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
   static spirv::StructType decorateType(spirv::StructType structType,
                                         Size &size, Size &alignment);

diff  --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc37445..51cfe4a68eb2d 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
     return decorateType(arrayType, size, alignment);
   if (auto vectorType = dyn_cast<VectorType>(type))
     return decorateType(vectorType, size, alignment);
+  if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+    return decorateType(matrixType, size, alignment);
   if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
     size = std::numeric_limits<Size>().max();
     return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
   return spirv::ArrayType::get(memberType, numElements, elementSize);
 }
 
+/// Lays out a matrix as an array of its column vectors and returns it;
+/// `size` and `alignment` receive that layout's byte size and alignment.
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+                                     VulkanLayoutUtils::Size &size,
+                                     VulkanLayoutUtils::Size &alignment) {
+  const unsigned numColumns = matrixType.getNumColumns();
+  Size columnSize = 0;
+  Size columnAlignment = 1;
+  // Vulkan: "A column-major matrix has a base alignment equal to the base
+  // alignment of the matrix column type." Decorate the column to obtain it.
+  Type columnType =
+      decorateType(matrixType.getColumnType(), columnSize, columnAlignment);
+  // Column stride must be a multiple of that alignment: a 12-byte column with
+  // 16-byte alignment (vector<3xf32>) takes 16 bytes, so 3 x it is 48 bytes.
+  size = llvm::alignTo(columnSize, columnAlignment) * numColumns;
+  alignment = columnAlignment;
+  return spirv::MatrixType::get(columnType, numColumns);
+}
+
 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
                                      VulkanLayoutUtils::Size &alignment) {
   auto elementType = arrayType.getElementType();

diff  --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd..b63a08d96e6af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
 
 // -----
 
+// A matrix member with an explicit offset inside a struct prints back as-is.
+// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+// -----
+
 // expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
 func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
 


        


More information about the Mlir-commits mailing list