[llvm] 90cdc03 - [IR] Fix undiagnosed cases of structs containing scalable vectors (#113455)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 25 04:56:14 PDT 2024
Author: Jay Foad
Date: 2024-10-25T12:56:10+01:00
New Revision: 90cdc03e7f5bda2e31573d48450a8ac8fa856efa
URL: https://github.com/llvm/llvm-project/commit/90cdc03e7f5bda2e31573d48450a8ac8fa856efa
DIFF: https://github.com/llvm/llvm-project/commit/90cdc03e7f5bda2e31573d48450a8ac8fa856efa.diff
LOG: [IR] Fix undiagnosed cases of structs containing scalable vectors (#113455)
Type::isScalableTy and StructType::containsScalableVectorType failed to
detect some cases of structs containing scalable vectors because
containsScalableVectorType did not call back into isScalableTy to check
the element types. Fix this, which requires sharing the same Visited set
in both functions. Also change the external API so that callers are
never required to pass in a Visited set, and normalize the naming to
isScalableTy.
Added:
Modified:
llvm/include/llvm/IR/DerivedTypes.h
llvm/include/llvm/IR/Type.h
llvm/lib/AsmParser/LLParser.cpp
llvm/lib/IR/Type.cpp
llvm/lib/IR/Verifier.cpp
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
llvm/test/Verifier/scalable-global-vars.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index a24801d8bdf834..820b5c0707df6c 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -290,8 +290,8 @@ class StructType : public Type {
bool isSized(SmallPtrSetImpl<Type *> *Visited = nullptr) const;
/// Returns true if this struct contains a scalable vector.
- bool
- containsScalableVectorType(SmallPtrSetImpl<Type *> *Visited = nullptr) const;
+ bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
+ using Type::isScalableTy;
/// Returns true if this struct contains homogeneous scalable vector types.
/// Note that the definition of homogeneous scalable vector type is not
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 2f53197df19998..d563b25d600a0c 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -206,6 +206,7 @@ class Type {
bool isScalableTargetExtTy() const;
/// Return true if this is a type whose size is a known multiple of vscale.
+ bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
bool isScalableTy() const;
/// Return true if this is a FP type or a vector of FP.
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 6a2372c9751408..8ddb2efb0e26c2 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -8525,7 +8525,7 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
return error(Loc, "base element of getelementptr must be sized");
auto *STy = dyn_cast<StructType>(Ty);
- if (STy && STy->containsScalableVectorType())
+ if (STy && STy->isScalableTy())
return error(Loc, "getelementptr cannot target structure that contains "
"scalable vector type");
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index f618263f79c313..912b1a3960ef19 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -58,16 +58,19 @@ bool Type::isIntegerTy(unsigned Bitwidth) const {
return isIntegerTy() && cast<IntegerType>(this)->getBitWidth() == Bitwidth;
}
-bool Type::isScalableTy() const {
+bool Type::isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const {
if (const auto *ATy = dyn_cast<ArrayType>(this))
- return ATy->getElementType()->isScalableTy();
- if (const auto *STy = dyn_cast<StructType>(this)) {
- SmallPtrSet<Type *, 4> Visited;
- return STy->containsScalableVectorType(&Visited);
- }
+ return ATy->getElementType()->isScalableTy(Visited);
+ if (const auto *STy = dyn_cast<StructType>(this))
+ return STy->isScalableTy(Visited);
return getTypeID() == ScalableVectorTyID || isScalableTargetExtTy();
}
+bool Type::isScalableTy() const {
+ SmallPtrSet<const Type *, 4> Visited;
+ return isScalableTy(Visited);
+}
+
const fltSemantics &Type::getFltSemantics() const {
switch (getTypeID()) {
case HalfTyID: return APFloat::IEEEhalf();
@@ -394,30 +397,22 @@ StructType *StructType::get(LLVMContext &Context, ArrayRef<Type*> ETypes,
return ST;
}
-bool StructType::containsScalableVectorType(
- SmallPtrSetImpl<Type *> *Visited) const {
+bool StructType::isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const {
if ((getSubclassData() & SCDB_ContainsScalableVector) != 0)
return true;
if ((getSubclassData() & SCDB_NotContainsScalableVector) != 0)
return false;
- if (Visited && !Visited->insert(const_cast<StructType *>(this)).second)
+ if (!Visited.insert(this).second)
return false;
for (Type *Ty : elements()) {
- if (isa<ScalableVectorType>(Ty)) {
+ if (Ty->isScalableTy(Visited)) {
const_cast<StructType *>(this)->setSubclassData(
getSubclassData() | SCDB_ContainsScalableVector);
return true;
}
- if (auto *STy = dyn_cast<StructType>(Ty)) {
- if (STy->containsScalableVectorType(Visited)) {
- const_cast<StructType *>(this)->setSubclassData(
- getSubclassData() | SCDB_ContainsScalableVector);
- return true;
- }
- }
}
// For structures that are opaque, return false but do not set the
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index f34fe7594c8602..60e65392218dad 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4107,8 +4107,7 @@ void Verifier::visitGetElementPtrInst(GetElementPtrInst &GEP) {
Check(GEP.getSourceElementType()->isSized(), "GEP into unsized type!", &GEP);
if (auto *STy = dyn_cast<StructType>(GEP.getSourceElementType())) {
- SmallPtrSet<Type *, 4> Visited;
- Check(!STy->containsScalableVectorType(&Visited),
+ Check(!STy->isScalableTy(),
"getelementptr cannot target structure that contains scalable vector"
"type",
&GEP);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index c8b9f166b16020..971ace2a4f4716 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4087,7 +4087,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
if (LoadInst *L = dyn_cast<LoadInst>(Agg)) {
// Bail out if the aggregate contains scalable vector type
if (auto *STy = dyn_cast<StructType>(Agg->getType());
- STy && STy->containsScalableVectorType())
+ STy && STy->isScalableTy())
return nullptr;
// If the (non-volatile) load only has one use, we can rewrite this to a
diff --git a/llvm/test/Verifier/scalable-global-vars.ll b/llvm/test/Verifier/scalable-global-vars.ll
index 81882261e664ef..fb9a3067acba98 100644
--- a/llvm/test/Verifier/scalable-global-vars.ll
+++ b/llvm/test/Verifier/scalable-global-vars.ll
@@ -15,3 +15,17 @@
; CHECK-NEXT: ptr @ScalableVecStructGlobal
@ScalableVecStructGlobal = global { i32, <vscale x 4 x i32> } zeroinitializer
+; CHECK-NEXT: Globals cannot contain scalable types
+; CHECK-NEXT: ptr @StructTestGlobal
+%struct.test = type { <vscale x 1 x double>, <vscale x 1 x double> }
+@StructTestGlobal = global %struct.test zeroinitializer
+
+; CHECK-NEXT: Globals cannot contain scalable types
+; CHECK-NEXT: ptr @StructArrayTestGlobal
+%struct.array.test = type { [2 x <vscale x 1 x double>] }
+@StructArrayTestGlobal = global %struct.array.test zeroinitializer
+
+; CHECK-NEXT: Globals cannot contain scalable types
+; CHECK-NEXT: ptr @StructTargetTestGlobal
+%struct.target.test = type { target("aarch64.svcount"), target("aarch64.svcount") }
+@StructTargetTestGlobal = global %struct.target.test zeroinitializer
More information about the llvm-commits mailing list