[Mlir-commits] [mlir] [mlir][spirv] Rework type size calculation (PR #160162)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Sep 22 11:17:13 PDT 2025
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/160162
Similar to `::getExtensions` and `::getCapabilities`, introduce a single entry point for type size calculation.
Also fix a potential infinite recursion with `StructType`s (even non-recursive structs), although I don't know how to write a test for this without using C++. Struct types now fall through to the default case and report no size (`std::nullopt`). This is mostly NFC, modulo this potential bug fix.
>From 511c8a53f3b52c5d4e46d65ecaab774f785aaa7d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 22 Sep 2025 14:13:45 -0400
Subject: [PATCH] [mlir][spirv] Rework type size calculation
Similar to `::getExtensions` and `::getCapabilities`, introduce a single
entry point for type size calculation.
Also fix a potential infinite recursion with `StructType`s (even
non-recursive structs), although I don't know how to write a test for
this without using C++. Struct types now fall through to the default
case and report no size (`std::nullopt`). This is mostly NFC, modulo
this potential bug fix.
---
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 8 --
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 78 +++++++------------
2 files changed, 30 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 475e3f495e065..e46b576810316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
static bool isValid(FloatType);
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
-
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
@@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
/// Return true if the number of elements is known at compile time and is not
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
-
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V array type
@@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
-
- /// Returns the array size in bytes. Since array type may have an explicit
- /// stride declaration (in bytes), we also include it in the calculation.
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V image type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7c2f43bea9ddb..5ed7652987859 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
+#include <optional>
using namespace mlir;
using namespace mlir::spirv;
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
-std::optional<int64_t> ArrayType::getSizeInBytes() {
- auto elementType = llvm::cast<SPIRVType>(getElementType());
- std::optional<int64_t> size = elementType.getSizeInBytes();
- if (!size)
- return std::nullopt;
- return (*size + getArrayStride()) * getNumElements();
-}
-
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
}
}
-std::optional<int64_t> CompositeType::getSizeInBytes() {
- if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
- return arrayType.getSizeInBytes();
- if (auto structType = llvm::dyn_cast<StructType>(*this))
- return structType.getSizeInBytes();
- if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
- std::optional<int64_t> elementSize =
- llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
- if (!elementSize)
- return std::nullopt;
- return *elementSize * vectorType.getNumElements();
- }
- if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- std::optional<int64_t> elementSize =
- llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
- if (!elementSize)
- return std::nullopt;
- return *elementSize * tensorArmType.getNumElements();
- }
- return std::nullopt;
-}
-
//===----------------------------------------------------------------------===//
// CooperativeMatrixType
//===----------------------------------------------------------------------===//
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
#undef WIDTH_CASE
}
-std::optional<int64_t> ScalarType::getSizeInBytes() {
- auto bitWidth = getIntOrFloatBitWidth();
- // According to the SPIR-V spec:
- // "There is no physical size or bit pattern defined for values with boolean
- // type. If they are stored (in conjunction with OpVariable), they can only
- // be used with logical addressing operations, not physical, and only with
- // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
- // Private, Function, Input, and Output."
- if (bitWidth == 1)
- return std::nullopt;
- return bitWidth / 8;
-}
-
//===----------------------------------------------------------------------===//
// SPIRVType
//===----------------------------------------------------------------------===//
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
- return scalarType.getSizeInBytes();
- if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
- return compositeType.getSizeInBytes();
- return std::nullopt;
+ return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
+ .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
+ // According to the SPIR-V spec:
+ // "There is no physical size or bit pattern defined for values with
+ // boolean type. If they are stored (in conjunction with OpVariable),
+ // they can only be used with logical addressing operations, not
+ // physical, and only with non-externally visible shader Storage
+ // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
+ // Output."
+ int64_t bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 1)
+ return std::nullopt;
+ return bitWidth / 8;
+ })
+ .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
+ // Since array type may have an explicit stride declaration (in bytes),
+ // we also include it in the calculation.
+ auto elementType = cast<SPIRVType>(type.getElementType());
+ if (std::optional<int64_t> size = elementType.getSizeInBytes())
+ return (*size + type.getArrayStride()) * type.getNumElements();
+ return std::nullopt;
+ })
+ .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
+ if(std::optional<int64_t> elementSize =
+ cast<ScalarType>(type.getElementType()).getSizeInBytes())
+ return *elementSize * type.getNumElements();
+ return std::nullopt;
+ })
+ .Default(std::optional<int64_t>());
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list