[Mlir-commits] [mlir] c4cfc95 - [mlir][SPIRV] Add decorateType method for MatrixType (#112018)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 25 08:21:05 PDT 2025
Author: Mingzhu Yan
Date: 2025-05-25T08:21:01-07:00
New Revision: c4cfc95d76f250943e2a8c589afb1658ff1d1524
URL: https://github.com/llvm/llvm-project/commit/c4cfc95d76f250943e2a8c589afb1658ff1d1524
DIFF: https://github.com/llvm/llvm-project/commit/c4cfc95d76f250943e2a8c589afb1658ff1d1524.diff
LOG: [mlir][SPIRV] Add decorateType method for MatrixType (#112018)
Fixes #108161
This PR adds a decorateType overload for MatrixType so that a
`spirv.matrix` member with an explicit offset inside a `spirv.struct` is
handled correctly.
Signed-off-by: MingZhu Yan <yanmingzhu at iscas.ac.cn>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
mlir/test/Dialect/SPIRV/IR/types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2d..72683d50d7411 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
class ArrayType;
class RuntimeArrayType;
class StructType;
+class MatrixType;
} // namespace spirv
/// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
static Type decorateType(spirv::ArrayType arrayType, Size &size,
Size &alignment);
+ static Type decorateType(spirv::MatrixType matrixType, Size &size,
+ Size &alignment);
static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
static spirv::StructType decorateType(spirv::StructType structType,
Size &size, Size &alignment);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc37445..51cfe4a68eb2d 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
return decorateType(arrayType, size, alignment);
if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
+ if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+ return decorateType(matrixType, size, alignment);
if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
return spirv::ArrayType::get(memberType, numElements, elementSize);
}
+// Computes the Vulkan layout size and alignment of a spirv::MatrixType.
+// `size` receives elementSize * numElements and `alignment` receives the
+// scalar element type's alignment; the returned type is rebuilt from the
+// original (undecorated) column type and column count.
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+ VulkanLayoutUtils::Size &size,
+ VulkanLayoutUtils::Size &alignment) {
+ const unsigned numColumns = matrixType.getNumColumns();
+ Type columnType = matrixType.getColumnType();
+ // Presumably rows * columns (total scalar count) - confirm in MatrixType.
+ unsigned numElements = matrixType.getNumElements();
+ Type elementType = matrixType.getElementType();
+ Size elementSize = 0;
+ Size elementAlignment = 1;
+
+ // Called only for its size/alignment out-params; the decorated type it
+ // returns is discarded.
+ decorateType(elementType, elementSize, elementAlignment);
+ // According to the Vulkan spec:
+ // "A matrix type inherits scalar alignment from the equivalent array
+ // declaration."
+ // Columns are treated as tightly packed (no padding up to a column stride).
+ // NOTE(review): the quoted sentence is the scalar-block-layout rule; under
+ // the standard rules a column-major matrix takes the base alignment of its
+ // column type, and no MatrixStride is derived here - confirm intended.
+ size = elementSize * numElements;
+ alignment = elementAlignment;
+ return spirv::MatrixType::get(columnType, numColumns);
+}
+
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
VulkanLayoutUtils::Size &alignment) {
auto elementType = arrayType.getElementType();
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd..b63a08d96e6af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
// -----
+// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+
+// -----
+
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
More information about the Mlir-commits mailing list