[llvm] [SPIR-V] Fix bad insertion for type/id MIR (PR #109686)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 05:36:42 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/109686

>From 372611e186fc7eaea4f94c3ee7516d5d0e655dcd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Fri, 20 Sep 2024 17:33:58 +0200
Subject: [PATCH 1/2] [SPIR-V] Fix bad insertion for type/id MIR
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Those instructions were inserted either after the instruction
using them, or in the middle of the module.
The first directly causes an issue. The second causes a more subtle
issue: the first time the type is inserted, the emission is fine, but
the second time, the first instruction is reused without checking its
position in the function. This can lead to the second usage dominating
the definition.

In SPIR-V, types are usually in the header, above all code definitions,
but at this stage I don't think we can place them there, so what I do
instead is emit them in the first basic block.

This commit reduces the failed tests with expensive checks from 107 to
71.

Signed-off-by: Nathan Gauër <brioche at google.com>

another method
---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 124 ++++++++++++------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  12 ++
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |   4 +-
 3 files changed, 95 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ca3e47a4b78f23..509603fcd8bd55 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
+#include <functional>
 
 using namespace llvm;
 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
@@ -83,8 +84,11 @@ inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
-  return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
-      .addDef(createTypeVReg(MIRBuilder));
+
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
+        .addDef(createTypeVReg(MIRBuilder));
+  });
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
@@ -118,24 +122,49 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
     MIRBuilder.buildInstr(SPIRV::OpCapability)
         .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
   }
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
-                 .addDef(createTypeVReg(MIRBuilder))
-                 .addImm(Width)
-                 .addImm(IsSigned ? 1 : 0);
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addImm(Width)
+        .addImm(IsSigned ? 1 : 0);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
                                                MachineIRBuilder &MIRBuilder) {
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
-                 .addDef(createTypeVReg(MIRBuilder))
-                 .addImm(Width);
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addImm(Width);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
-  return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
-      .addDef(createTypeVReg(MIRBuilder));
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
+        .addDef(createTypeVReg(MIRBuilder));
+  });
+}
+
+SPIRVType *SPIRVGlobalRegistry::createOpType(
+    MachineIRBuilder &MIRBuilder,
+    std::function<MachineInstr *(MachineIRBuilder &)> Op) {
+  auto oldInsertPoint = MIRBuilder.getInsertPt();
+  MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
+
+  if (LastInsertedType == nullptr) {
+    MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
+                           MIRBuilder.getMF().begin()->begin());
+  } else {
+    MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
+                           LastInsertedType->getIterator());
+  }
+
+  MachineInstr *Type = Op(MIRBuilder);
+  LastInsertedType = Type;
+
+  MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
+  return Type;
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
@@ -147,11 +176,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
           EleOpc == SPIRV::OpTypeBool) &&
          "Invalid vector element type");
 
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
-                 .addDef(createTypeVReg(MIRBuilder))
-                 .addUse(getSPIRVTypeID(ElemType))
-                 .addImm(NumElems);
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addUse(getSPIRVTypeID(ElemType))
+        .addImm(NumElems);
+  });
 }
 
 std::tuple<Register, ConstantInt *, bool, unsigned>
@@ -688,11 +718,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
   SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
   Register NumElementsVReg =
       buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
-                 .addDef(createTypeVReg(MIRBuilder))
-                 .addUse(getSPIRVTypeID(ElemType))
-                 .addUse(NumElementsVReg);
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addUse(getSPIRVTypeID(ElemType))
+        .addUse(NumElementsVReg);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
@@ -700,10 +731,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
   assert(Ty->hasName());
   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
   Register ResVReg = createTypeVReg(MIRBuilder);
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
-  addStringImm(Name, MIB);
-  buildOpName(ResVReg, Name, MIRBuilder);
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
+    addStringImm(Name, MIB);
+    buildOpName(ResVReg, Name, MIRBuilder);
+    return MIB;
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
@@ -717,14 +750,16 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
   }
   Register ResVReg = createTypeVReg(MIRBuilder);
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
-  for (const auto &Ty : FieldTypes)
-    MIB.addUse(Ty);
-  if (Ty->hasName())
-    buildOpName(ResVReg, Ty->getName(), MIRBuilder);
-  if (Ty->isPacked())
-    buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
-  return MIB;
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
+    for (const auto &Ty : FieldTypes)
+      MIB.addUse(Ty);
+    if (Ty->hasName())
+      buildOpName(ResVReg, Ty->getName(), MIRBuilder);
+    if (Ty->isPacked())
+      buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
+    return MIB;
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
@@ -739,17 +774,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
     MachineIRBuilder &MIRBuilder, Register Reg) {
   if (!Reg.isValid())
     Reg = createTypeVReg(MIRBuilder);
-  return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
-      .addDef(Reg)
-      .addImm(static_cast<uint32_t>(SC))
-      .addUse(getSPIRVTypeID(ElemType));
+
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
+        .addDef(Reg)
+        .addImm(static_cast<uint32_t>(SC))
+        .addUse(getSPIRVTypeID(ElemType));
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
-  return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
-      .addUse(createTypeVReg(MIRBuilder))
-      .addImm(static_cast<uint32_t>(SC));
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
+        .addUse(createTypeVReg(MIRBuilder))
+        .addImm(static_cast<uint32_t>(SC));
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ed9cfc07132430..4831e3437dda0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -64,6 +64,10 @@ class SPIRVGlobalRegistry {
   SmallPtrSet<const Type *, 4> TypesInProcessing;
   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
 
+  // The last MIR inserted defining a SPIR-V Type.
+  // See: SPIRVGlobalRegistry::createOpType.
+  MachineInstr *LastInsertedType = nullptr;
+
   // if a function returns a pointer, this is to map it into TypedPointerType
   DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
 
@@ -97,6 +101,13 @@ class SPIRVGlobalRegistry {
                         SPIRV::AccessQualifier::AccessQualifier AccessQual,
                         bool EmitIR);
 
+  // Internal function creating an OpType at the correct position in the
+  // function by tweaking the passed "MIRBuilder" insertion point and restoring
+  // it to the correct position. "Op" should be the function creating the
+  // specific OpType you need, and should return the newly created instruction.
+  SPIRVType *createOpType(MachineIRBuilder &MIRBuilder,
+                          std::function<MachineInstr *(MachineIRBuilder &)> Op);
+
 public:
   SPIRVGlobalRegistry(unsigned PointerSize);
 
@@ -336,6 +347,7 @@ class SPIRVGlobalRegistry {
   MachineFunction *setCurrentFunc(MachineFunction &MF) {
     MachineFunction *Ret = CurMF;
     CurMF = &MF;
+    LastInsertedType = nullptr;
     return Ret;
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index f1b10e264781f2..cd0aff1a518439 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -389,9 +389,7 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
   AssignTypeInst.getOperand(1).setReg(NewReg);
   MI.getOperand(0).setReg(NewReg);
-  MIB.setInsertPt(*MI.getParent(),
-                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
-                                    : MI.getParent()->end()));
+  MIB.setInsertPt(*MI.getParent(), MI.getIterator());
   for (auto &Op : MI.operands()) {
     if (!Op.isReg() || Op.isDef())
       continue;

>From 62c2c8a5fd066ab0787cad39e8f7d4f417e3c297 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 24 Sep 2024 14:31:42 +0200
Subject: [PATCH 2/2] pr-feedback: keep state across context switch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The code can switch functions multiple times, and come
back to a previously visited function, meaning we need to keep
the last inserted type for each function.

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 12 ++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  5 ++---
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 509603fcd8bd55..3e1873e899680c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -152,16 +152,20 @@ SPIRVType *SPIRVGlobalRegistry::createOpType(
   auto oldInsertPoint = MIRBuilder.getInsertPt();
   MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
 
-  if (LastInsertedType == nullptr) {
+  auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
+  if (LastInsertedType != LastInsertedTypeMap.end()) {
     MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
-                           MIRBuilder.getMF().begin()->begin());
+                           LastInsertedType->second->getIterator());
   } else {
     MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
-                           LastInsertedType->getIterator());
+                           MIRBuilder.getMF().begin()->begin());
+    auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
+    assert(Result.second);
+    LastInsertedType = Result.first;
   }
 
   MachineInstr *Type = Op(MIRBuilder);
-  LastInsertedType = Type;
+  LastInsertedType->second = Type;
 
   MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
   return Type;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 4831e3437dda0d..cad2bf96adf33e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -64,9 +64,9 @@ class SPIRVGlobalRegistry {
   SmallPtrSet<const Type *, 4> TypesInProcessing;
   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
 
-  // The last MIR inserted defining a SPIR-V Type.
+  // Stores for each function the last inserted SPIR-V Type.
   // See: SPIRVGlobalRegistry::createOpType.
-  MachineInstr *LastInsertedType = nullptr;
+  DenseMap<const MachineFunction *, MachineInstr *> LastInsertedTypeMap;
 
   // if a function returns a pointer, this is to map it into TypedPointerType
   DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
@@ -347,7 +347,6 @@ class SPIRVGlobalRegistry {
   MachineFunction *setCurrentFunc(MachineFunction &MF) {
     MachineFunction *Ret = CurMF;
     CurMF = &MF;
-    LastInsertedType = nullptr;
     return Ret;
   }
 



More information about the llvm-commits mailing list