[Mlir-commits] [mlir] [mlir][SPIRV] Add decorateType method for MatrixType (PR #112018)
MingZhu Yan
llvmlistbot at llvm.org
Wed Oct 30 17:55:57 PDT 2024
https://github.com/trdthg updated https://github.com/llvm/llvm-project/pull/112018
>From cd90206725525140c4d460e2cdb053cdf71c21f1 Mon Sep 17 00:00:00 2001
From: MingZhu Yan <yanmingzhu at iscas.ac.cn>
Date: Sat, 12 Oct 2024 11:16:15 +0800
Subject: [PATCH] [mlir][SPIRV] Add decorateType method for MatrixType
This PR adds a `decorateType` overload for `spirv::MatrixType`,
so that a `spirv.matrix` member with an offset inside a `spirv.struct`
is handled correctly by the Vulkan layout utilities.
Signed-off-by: MingZhu Yan <yanmingzhu at iscas.ac.cn>
---
.../mlir/Dialect/SPIRV/Utils/LayoutUtils.h | 3 +++
mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp | 21 +++++++++++++++++++
mlir/test/Dialect/SPIRV/IR/types.mlir | 5 +++++
3 files changed, 29 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2da..72683d50d74117 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
class ArrayType;
class RuntimeArrayType;
class StructType;
+class MatrixType;
} // namespace spirv
/// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
static Type decorateType(spirv::ArrayType arrayType, Size &size,
Size &alignment);
+ static Type decorateType(spirv::MatrixType matrixType, Size &size,
+ Size &alignment);
static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
static spirv::StructType decorateType(spirv::StructType structType,
Size &size, Size &alignment);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc374452..48124d153d58cd 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
return decorateType(arrayType, size, alignment);
if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
+ if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+ return decorateType(matrixType, size, alignment);
if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
return spirv::ArrayType::get(memberType, numElements, elementSize);
}
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+                                     VulkanLayoutUtils::Size &size,
+                                     VulkanLayoutUtils::Size &alignment) {
+  const unsigned numColumns = matrixType.getNumColumns();
+  Type columnType = matrixType.getColumnType();
+  const unsigned numElements = matrixType.getNumElements();
+  Type elementType = matrixType.getElementType();
+  Size elementSize = 0;
+  Size elementAlignment = 1;
+  // Only the element type is measured; the column type is reused unchanged.
+  decorateType(elementType, elementSize, elementAlignment);
+  // Vulkan spec: "A matrix type inherits scalar alignment from the equivalent
+  // array declaration." NOTE(review): base alignment (non-scalar layouts)
+  // would instead follow the column type; confirm scalar is intended here.
+  size = elementSize * numElements;
+  alignment = elementAlignment;
+  return spirv::MatrixType::get(columnType, numColumns);
+}
+
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
VulkanLayoutUtils::Size &alignment) {
auto elementType = arrayType.getElementType();
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd9..b63a08d96e6af9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
// -----
+// CHECK: func private @matrix_type_in_struct(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type_in_struct(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+
+// -----
+
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
More information about the Mlir-commits mailing list