[llvm] [SPIR-V] Fix isAggregateType function implementation (PR #187685)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 20 05:08:49 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Arseniy Obolenskiy (aobolensk)

<details>
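<summary>Changes</summary>

`SPIRVGlobalRegistry::isAggregateType` joined its two opcode checks with `&&`, so a type would have had to be both `OpTypeStruct` and `OpTypeArray` at once; the function therefore always returned false. This change uses `||` so that struct and array types are both reported as aggregates, and adds a unit test (`SPIRVGlobalRegistryTests.cpp`) covering struct, array, non-aggregate (`OpTypeFloat`), and null inputs.

---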
Full diff: https://github.com/llvm/llvm-project/pull/187685.diff


3 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+1-1) 
- (modified) llvm/unittests/Target/SPIRV/CMakeLists.txt (+1) 
- (added) llvm/unittests/Target/SPIRV/SPIRVGlobalRegistryTests.cpp (+73) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7296f2c0bf351..a77dd7e3bc265 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -369,7 +369,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   // Return true if the type is an aggregate type.
   bool isAggregateType(SPIRVTypeInst Type) const {
-    return Type && (Type->getOpcode() == SPIRV::OpTypeStruct &&
+    return Type && (Type->getOpcode() == SPIRV::OpTypeStruct ||
                     Type->getOpcode() == SPIRV::OpTypeArray);
   }
 
 
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index 29b31b16094a0..4740044dfcb84 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -17,6 +17,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_target_unittest(SPIRVTests
   SPIRVConvergenceRegionAnalysisTests.cpp
+  SPIRVGlobalRegistryTests.cpp
   SPIRVSortBlocksTests.cpp
   SPIRVPartialOrderingVisitorTests.cpp
   SPIRVAPITest.cpp
diff --git a/llvm/unittests/Target/SPIRV/SPIRVGlobalRegistryTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVGlobalRegistryTests.cpp
new file mode 100644
index 0000000000000..80364def71367
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVGlobalRegistryTests.cpp
@@ -0,0 +1,73 @@
+//===- SPIRVGlobalRegistryTests.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVInstrInfo.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/IR/Module.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+class SPIRVGlobalRegistryTest : public testing::Test {
+protected:
+  static void SetUpTestSuite() {
+    LLVMInitializeSPIRVTargetInfo();
+    LLVMInitializeSPIRVTarget();
+    LLVMInitializeSPIRVTargetMC();
+  }
+
+  void SetUp() override {
+    Triple TT("spirv64-unknown-unknown");
+    std::string Error;
+    const Target *T = TargetRegistry::lookupTarget(TT, Error);
+    if (!T)
+      GTEST_SKIP();
+    TargetOptions Options;
+    TM.reset(T->createTargetMachine(TT, "", "", Options, std::nullopt,
+                                    std::nullopt));
+    Ctx = std::make_unique<LLVMContext>();
+    Mod = std::make_unique<Module>("M", *Ctx);
+    Mod->setDataLayout(TM->createDataLayout());
+    auto *F = Function::Create(FunctionType::get(Type::getVoidTy(*Ctx), false),
+                               GlobalValue::ExternalLinkage, "f", *Mod);
+    MMI = std::make_unique<MachineModuleInfo>(TM.get());
+    MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
+                                           MMI->getContext(), 0);
+    MBB = MF->CreateMachineBasicBlock();
+    MF->push_back(MBB);
+  }
+
+  SPIRVTypeInst makeTypeInstr(unsigned Opcode) {
+    auto &TII =
+        *static_cast<const SPIRVInstrInfo *>(MF->getSubtarget().getInstrInfo());
+    Register Reg = MF->getRegInfo().createVirtualRegister(&SPIRV::TYPERegClass);
+    return BuildMI(*MBB, MBB->end(), DebugLoc(), TII.get(Opcode))
+        .addDef(Reg)
+        .getInstr();
+  }
+
+  std::unique_ptr<TargetMachine> TM;
+  std::unique_ptr<LLVMContext> Ctx;
+  std::unique_ptr<Module> Mod;
+  std::unique_ptr<MachineModuleInfo> MMI;
+  std::unique_ptr<MachineFunction> MF;
+  MachineBasicBlock *MBB = nullptr;
+};
+
+TEST_F(SPIRVGlobalRegistryTest, IsAggregateType) {
+  SPIRVGlobalRegistry GR(8);
+  EXPECT_TRUE(GR.isAggregateType(makeTypeInstr(SPIRV::OpTypeStruct)));
+  EXPECT_TRUE(GR.isAggregateType(makeTypeInstr(SPIRV::OpTypeArray)));
+  EXPECT_FALSE(GR.isAggregateType(makeTypeInstr(SPIRV::OpTypeFloat)));
+  EXPECT_FALSE(GR.isAggregateType(SPIRVTypeInst(nullptr)));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/187685


More information about the llvm-commits mailing list