[Mlir-commits] [mlir] [mlir][SPIRV] Add decorateType method for MatrixType (PR #112018)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 11 09:26:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: MingZhu Yan (trdthg)

<details>
<summary>Changes</summary>

Attempts to fix #108161.

This PR adds a `decorateType` overload for `spirv::MatrixType` to `VulkanLayoutUtils`, so that a `spirv.matrix` member of a `spirv.struct` that carries an explicit offset is handled correctly.

---
Full diff: https://github.com/llvm/llvm-project/pull/112018.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h (+3) 
- (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+21) 
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2da..72683d50d74117 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
 class ArrayType;
 class RuntimeArrayType;
 class StructType;
+class MatrixType;
 } // namespace spirv
 
 /// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
   static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
   static Type decorateType(spirv::ArrayType arrayType, Size &size,
                            Size &alignment);
+  static Type decorateType(spirv::MatrixType matrixType, Size &size,
+                           Size &alignment);
   static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
   static spirv::StructType decorateType(spirv::StructType structType,
                                         Size &size, Size &alignment);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc374452..ede9397fbc552e 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
     return decorateType(arrayType, size, alignment);
   if (auto vectorType = dyn_cast<VectorType>(type))
     return decorateType(vectorType, size, alignment);
+  if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+    return decorateType(matrixType, size, alignment);
   if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
     size = std::numeric_limits<Size>().max();
     return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
   return spirv::ArrayType::get(memberType, numElements, elementSize);
 }
 
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+                                     VulkanLayoutUtils::Size &size,
+                                     VulkanLayoutUtils::Size &alignment) {
+  const auto numColumns = matrixType.getNumColumns();
+  const auto columnType = matrixType.getColumnType();
+  const auto numElements = matrixType.getNumElements();
+  auto elementType = matrixType.getElementType();
+  Size elementSize = 0;
+  Size elementAlignment = 1;
+
+  decorateType(elementType, elementSize, elementAlignment);
+  // According to the Vulkan spec:
+  // "A matrix type inherits scalar alignment from the equivalent array
+  // declaration."
+  size = elementSize * numElements;
+  alignment = elementAlignment;
+  return spirv::MatrixType::get(columnType, numColumns);
+}
+
 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
                                      VulkanLayoutUtils::Size &alignment) {
   auto elementType = arrayType.getElementType();
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd9..b63a08d96e6af9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
 
 // -----
 
+// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+
+// -----
+
 // expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
 func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/112018


More information about the Mlir-commits mailing list