[Mlir-commits] [mlir] [mlir][spirv] Rework type size calculation (PR #160162)

Jakub Kuderski llvmlistbot at llvm.org
Mon Sep 22 11:17:13 PDT 2025


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/160162

Similar to `::getExtensions` and `::getCapabilities`, introduce a single entry point for type size calculation.

Also fix potential infinite recursion with `StructType`s (even non-recursive structs), whose size is now reported as unknown (`std::nullopt`), although I don't know how to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.

>From 511c8a53f3b52c5d4e46d65ecaab774f785aaa7d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 22 Sep 2025 14:13:45 -0400
Subject: [PATCH] [mlir][spirv] Rework type size calculation

Similar to `::getExtensions` and `::getCapabilities`, introduce a single
entry point for type size calculation.

Also fix potential infinite recursion with `StructType`s (even
non-recursive structs), whose size is now reported as unknown
(`std::nullopt`), although I don't know how to write a test for this
without using C++. This is mostly an NFC modulo this potential bug fix.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  8 --
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 78 +++++++------------
 2 files changed, 30 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 475e3f495e065..e46b576810316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
   static bool isValid(FloatType);
   /// Returns true if the given float type is valid for the SPIR-V dialect.
   static bool isValid(IntegerType);
-
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
@@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
   /// Return true if the number of elements is known at compile time and is not
   /// implementation dependent.
   bool hasCompileTimeKnownNumElements() const;
-
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V array type
@@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
   /// Returns the array stride in bytes. 0 means no stride decorated on this
   /// type.
   unsigned getArrayStride() const;
-
-  /// Returns the array size in bytes. Since array type may have an explicit
-  /// stride declaration (in bytes), we also include it in the calculation.
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V image type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7c2f43bea9ddb..5ed7652987859 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/ErrorHandling.h"
 
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
-std::optional<int64_t> ArrayType::getSizeInBytes() {
-  auto elementType = llvm::cast<SPIRVType>(getElementType());
-  std::optional<int64_t> size = elementType.getSizeInBytes();
-  if (!size)
-    return std::nullopt;
-  return (*size + getArrayStride()) * getNumElements();
-}
-
 //===----------------------------------------------------------------------===//
 // CompositeType
 //===----------------------------------------------------------------------===//
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
   }
 }
 
-std::optional<int64_t> CompositeType::getSizeInBytes() {
-  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
-    return arrayType.getSizeInBytes();
-  if (auto structType = llvm::dyn_cast<StructType>(*this))
-    return structType.getSizeInBytes();
-  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
-    std::optional<int64_t> elementSize =
-        llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
-    if (!elementSize)
-      return std::nullopt;
-    return *elementSize * vectorType.getNumElements();
-  }
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    std::optional<int64_t> elementSize =
-        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
-    if (!elementSize)
-      return std::nullopt;
-    return *elementSize * tensorArmType.getNumElements();
-  }
-  return std::nullopt;
-}
-
 //===----------------------------------------------------------------------===//
 // CooperativeMatrixType
 //===----------------------------------------------------------------------===//
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
 #undef WIDTH_CASE
 }
 
-std::optional<int64_t> ScalarType::getSizeInBytes() {
-  auto bitWidth = getIntOrFloatBitWidth();
-  // According to the SPIR-V spec:
-  // "There is no physical size or bit pattern defined for values with boolean
-  // type. If they are stored (in conjunction with OpVariable), they can only
-  // be used with logical addressing operations, not physical, and only with
-  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
-  // Private, Function, Input, and Output."
-  if (bitWidth == 1)
-    return std::nullopt;
-  return bitWidth / 8;
-}
-
 //===----------------------------------------------------------------------===//
 // SPIRVType
 //===----------------------------------------------------------------------===//
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
 }
 
 std::optional<int64_t> SPIRVType::getSizeInBytes() {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
-    return scalarType.getSizeInBytes();
-  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
-    return compositeType.getSizeInBytes();
-  return std::nullopt;
+  return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
+      .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
+        // According to the SPIR-V spec:
+        // "There is no physical size or bit pattern defined for values with
+        // boolean type. If they are stored (in conjunction with OpVariable),
+        // they can only be used with logical addressing operations, not
+        // physical, and only with non-externally visible shader Storage
+        // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
+        // Output."
+        int64_t bitWidth = type.getIntOrFloatBitWidth();
+        if (bitWidth == 1)
+          return std::nullopt;
+        return bitWidth / 8;
+      })
+      .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
+        // Since array type may have an explicit stride declaration (in bytes),
+        // we also include it in the calculation.
+        auto elementType = cast<SPIRVType>(type.getElementType());
+        if (std::optional<int64_t> size = elementType.getSizeInBytes())
+          return (*size + type.getArrayStride()) * type.getNumElements();
+        return std::nullopt;
+      })
+      .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
+        if(std::optional<int64_t> elementSize =
+                cast<ScalarType>(type.getElementType()).getSizeInBytes())
+          return *elementSize * type.getNumElements();
+        return std::nullopt;
+      })
+      .Default(std::optional<int64_t>());
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list