[llvm] [AMDGPU] Add TDM Descriptor Optimization Pass (PR #173324)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 23 03:15:30 PST 2025


================
@@ -0,0 +1,515 @@
+//===-- AMDGPUTDMOptimization.cpp - TDM Descriptor Optimization ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass optimizes Tensor Data Movement (TDM) descriptor creation patterns.
+// It identifies insertelement chains that create descriptors and transforms
+// them to use alloca+field updates, which SROA later optimizes to
+// INSERT_SUBREG.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "AMDGPUSubtarget.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-tdm-optimization"
+
+// Profitability cutoff: groups whose estimated benefit (see
+// DescriptorGroup::getOptimizationBenefit) falls below this value are left
+// untransformed.
+static cl::opt<unsigned>
+    TDMOptBenefitThreshold("amdgpu-tdm-opt-threshold", cl::Hidden, cl::init(10),
+                           cl::desc("Minimum optimization benefit threshold "
+                                    "for TDM descriptor optimization"));
+
+// Forward declaration so the constructor can register the pass with the
+// PassRegistry before the INITIALIZE_PASS machinery is visible.
+namespace llvm {
+void initializeAMDGPUTDMOptimizationPass(PassRegistry &);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection Data Structures
+//===----------------------------------------------------------------------===//
+
+/// Represents a single descriptor creation pattern: a chain of insertelement
+/// instructions that builds a <4 x i32> or <8 x i32> descriptor from a base
+/// vector, rewriting some fields and leaving the rest untouched.
+struct DescriptorPattern {
+  Type *DescType;   ///< <4 x i32> or <8 x i32>
+  Value *BaseValue; ///< Base template (constant or computed)
+  SmallVector<InsertElementInst *, 8>
+      Chain; ///< Chain of insertelement instructions
+  SmallVector<unsigned, 8> VariableFields; ///< Fields that change
+  SmallVector<unsigned, 8> ConstantFields; ///< Fields that stay constant
+  BasicBlock *Location;                    ///< Where the pattern is located
+  Loop *ContainingLoop; ///< Loop containing this pattern (if any)
+
+  /// Calculate field reuse ratio (constant fields / total fields).
+  /// Higher ratios mean more of the descriptor can be shared across uses.
+  float getFieldReuseRatio() const {
+    unsigned TotalFields = cast<FixedVectorType>(DescType)->getNumElements();
+    return (float)ConstantFields.size() / TotalFields;
+  }
+
+  /// Check if this pattern is worth optimizing.
+  bool isWorthOptimizing() const {
+    // Always optimize if in a loop with reuse potential: the constant
+    // fields pay off on every iteration.
+    if (ContainingLoop && getFieldReuseRatio() >= 0.5f)
+      return true;
+
+    // Optimize if significant field reuse even outside loops.
+    if (getFieldReuseRatio() >= 0.75f)
+      return true;
+
+    // Optimize address descriptors (common case) with any reusable field.
+    if (isAddressDescriptor() && ConstantFields.size() >= 1)
+      return true;
+
+    return false;
+  }
+
+  /// Check if this is an address descriptor (<4 x i32>).
+  bool isAddressDescriptor() const {
+    auto *VecTy = cast<FixedVectorType>(DescType);
+    return VecTy->getNumElements() == 4 &&
+           VecTy->getElementType()->isIntegerTy(32);
+  }
+
+  /// Check if this is a tensor descriptor (<8 x i32>).
+  bool isTensorDescriptor() const {
+    auto *VecTy = cast<FixedVectorType>(DescType);
+    return VecTy->getNumElements() == 8 &&
+           VecTy->getElementType()->isIntegerTy(32);
+  }
+};
+
+/// Groups similar descriptor patterns for optimization.
+struct DescriptorGroup {
+  SmallVector<DescriptorPattern, 4> Patterns;
+  Type *SharedType;
+  Value *SharedBase; ///< Common base value (if any)
+  SmallVector<unsigned, 8> SharedConstantFields;
+
+  /// Calculate total optimization benefit.
+  ///
+  /// Each pattern contributes 2 points per reusable (constant) field; a
+  /// pattern inside a loop counts 5x because the reuse pays off on every
+  /// iteration. The loop multiplier is applied to the pattern's own
+  /// contribution, not the running total, so the estimate does not depend
+  /// on the order in which patterns were added to the group.
+  unsigned getOptimizationBenefit() const {
+    unsigned Benefit = 0;
+    for (const auto &Pattern : Patterns) {
+      // Base benefit from field reuse.
+      unsigned PatternBenefit = Pattern.ConstantFields.size() * 2;
+
+      // Extra benefit for loop patterns.
+      if (Pattern.ContainingLoop)
+        PatternBenefit *= 5;
+
+      Benefit += PatternBenefit;
+    }
+    return Benefit;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AMDGPUTDMOptimization Pass
+//===----------------------------------------------------------------------===//
+
+/// Legacy-PM function pass: detects insertelement chains that build TDM
+/// descriptors, groups similar ones, and rewrites them to shared
+/// alloca-based storage that SROA can later turn into INSERT_SUBREG.
+class AMDGPUTDMOptimization : public FunctionPass {
+private:
+  /// Loop analysis for the current function; set in runOnFunction.
+  LoopInfo *LI = nullptr;
+
+  /// Detected patterns in the function (cleared after each function).
+  SmallVector<DescriptorPattern, 16> DetectedPatterns;
+
+  /// Groups of optimizable patterns (cleared after each function).
+  SmallVector<DescriptorGroup, 8> OptimizationGroups;
+
+public:
+  static char ID;
+
+  AMDGPUTDMOptimization() : FunctionPass(ID) {
+    initializeAMDGPUTDMOptimizationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  /// Main optimization phases, run in order by runOnFunction.
+  bool detectDescriptorPatterns(Function &F);
+  void groupSimilarPatterns();
+  bool transformPatterns(Function &F);
+
+  /// Pattern detection helpers.
+  bool isDescriptorType(Type *Ty) const;
+  DescriptorPattern analyzeInsertChain(InsertElementInst *FinalInsert);
+  Value *extractBaseValue(const DescriptorPattern &Pattern);
+
+  /// Transformation helpers.
+  bool transformDescriptorGroup(DescriptorGroup &Group, Function &F);
+  Value *createSharedStorage(DescriptorGroup &Group, IRBuilder<> &Builder);
+  void transformSinglePattern(DescriptorPattern &Pattern, Value *SharedStorage,
+                              IRBuilder<> &Builder);
+
+  /// Utility functions.
+  Loop *getContainingLoop(BasicBlock *BB);
+  bool arePatternsSimilar(const DescriptorPattern &A,
+                          const DescriptorPattern &B);
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::runOnFunction(Function &F) {
+  // Honor optnone and opt-bisect: legacy-PM passes must bail out explicitly.
+  if (skipFunction(F))
+    return false;
+
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+  LLVM_DEBUG(dbgs() << "Running TDM optimization on function: " << F.getName()
+                    << "\n");
+
+  // Phase 1: Detect descriptor patterns.
+  if (!detectDescriptorPatterns(F)) {
+    LLVM_DEBUG(dbgs() << "No descriptor patterns found\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "Found " << DetectedPatterns.size()
+                    << " descriptor patterns\n");
+
+  // Phase 2: Group similar patterns for optimization.
+  groupSimilarPatterns();
+
+  LLVM_DEBUG(dbgs() << "Created " << OptimizationGroups.size()
+                    << " optimization groups\n");
+
+  // Phase 3: Transform patterns.
+  bool Changed = transformPatterns(F);
+
+  // Cleanup state so the next function starts fresh.
+  DetectedPatterns.clear();
+  OptimizationGroups.clear();
+
+  return Changed;
+}
+
+void AMDGPUTDMOptimization::getAnalysisUsage(AnalysisUsage &AU) const {
+  // Only instructions inside existing blocks are rewritten; the CFG is
+  // untouched.
+  AU.setPreservesCFG();
+  // Loop membership drives both the profitability heuristic and the
+  // decision to hoist shared storage into a preheader.
+  AU.addRequired<LoopInfoWrapperPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::detectDescriptorPatterns(Function &F) {
+  // Walk every instruction looking for the tail of an insertelement chain
+  // that produces a descriptor-typed vector.
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      auto *Insert = dyn_cast<InsertElementInst>(&I);
+      if (!Insert || !isDescriptorType(Insert->getType()))
+        continue;
+
+      // Only the final insert of a chain qualifies: it must have exactly
+      // one user, and that user must not be a further insertelement.
+      if (!Insert->hasOneUse() ||
+          isa<InsertElementInst>(*Insert->user_begin()))
+        continue;
+
+      // Walk the complete chain back to its base vector.
+      DescriptorPattern Pattern = analyzeInsertChain(Insert);
+      if (Pattern.Chain.empty())
+        continue;
+
+      // Apply the profitability heuristic before recording the pattern.
+      if (!Pattern.isWorthOptimizing()) {
+        LLVM_DEBUG(
+            dbgs() << "Pattern not worth optimizing: field reuse ratio = "
+                   << Pattern.getFieldReuseRatio() << "\n");
+        continue;
+      }
+
+      LLVM_DEBUG(
+          dbgs() << "Found optimizable pattern: "
+                 << (Pattern.isAddressDescriptor() ? "Address" : "Tensor")
+                 << " descriptor with " << Pattern.ConstantFields.size()
+                 << " constant fields\n");
+
+      DetectedPatterns.push_back(std::move(Pattern));
+    }
+  }
+
+  // DetectedPatterns is empty on entry, so this mirrors the old flag.
+  return !DetectedPatterns.empty();
+}
+
+bool AMDGPUTDMOptimization::isDescriptorType(Type *Ty) const {
+  // Descriptors are fixed vectors of i32: <4 x i32> for address
+  // descriptors, <8 x i32> for tensor descriptors.
+  auto *VecTy = dyn_cast<FixedVectorType>(Ty);
+  if (!VecTy)
+    return false;
+  if (!VecTy->getElementType()->isIntegerTy(32))
+    return false;
+  const unsigned NumElements = VecTy->getNumElements();
+  return NumElements == 4 || NumElements == 8;
+}
+
+DescriptorPattern
+AMDGPUTDMOptimization::analyzeInsertChain(InsertElementInst *FinalInsert) {
+  DescriptorPattern Pattern;
+  Pattern.DescType = FinalInsert->getType();
+  Pattern.Location = FinalInsert->getParent();
+  Pattern.ContainingLoop = getContainingLoop(Pattern.Location);
+  Pattern.BaseValue = nullptr;
+
+  unsigned NumElements =
+      cast<FixedVectorType>(Pattern.DescType)->getNumElements();
+
+  // Trace back the insertelement chain. Reject the whole chain if any
+  // insert uses a non-constant or out-of-range index: we cannot tell which
+  // field it rewrites, so the constant/variable classification below would
+  // be wrong. An empty Chain tells the caller to skip this candidate.
+  SmallVector<InsertElementInst *, 8> Chain;
+  Value *CurrentVal = FinalInsert;
+  while (auto *IE = dyn_cast<InsertElementInst>(CurrentVal)) {
+    auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
+    if (!CI || CI->getZExtValue() >= NumElements)
+      return Pattern;
+    Chain.push_back(IE);
+    CurrentVal = IE->getOperand(0); // Vector being inserted into
+  }
+
+  // Reverse to get forward order (first insert ... FinalInsert).
+  std::reverse(Chain.begin(), Chain.end());
+  Pattern.Chain = Chain;
+
+  // Extract base value (the initial vector feeding the first insert).
+  Pattern.BaseValue = extractBaseValue(Pattern);
+
+  // Record each written field once, in forward order, even if the chain
+  // writes it multiple times.
+  SmallBitVector FieldSet(NumElements, false);
+  for (auto *IE : Chain) {
+    unsigned Idx = cast<ConstantInt>(IE->getOperand(2))->getZExtValue();
+    if (!FieldSet.test(Idx)) {
+      FieldSet.set(Idx);
+      Pattern.VariableFields.push_back(Idx);
+    }
+  }
+
+  // Fields never written by the chain keep their base value.
+  for (unsigned i = 0; i < NumElements; ++i) {
+    if (!FieldSet[i])
+      Pattern.ConstantFields.push_back(i);
+  }
+
+  return Pattern;
+}
+
+Value *
+AMDGPUTDMOptimization::extractBaseValue(const DescriptorPattern &Pattern) {
+  // The base is whatever vector feeds the first insert in the chain: a
+  // constant template, a shufflevector result, or any other computed
+  // value. All cases are returned as-is, so no case analysis is needed.
+  if (Pattern.Chain.empty())
+    return nullptr;
+  return Pattern.Chain[0]->getOperand(0);
+}
+
+Loop *AMDGPUTDMOptimization::getContainingLoop(BasicBlock *BB) {
+  // LI is set in runOnFunction; guard anyway so helpers stay safe to call.
+  if (!LI)
+    return nullptr;
+  return LI->getLoopFor(BB);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Grouping
+//===----------------------------------------------------------------------===//
+
+void AMDGPUTDMOptimization::groupSimilarPatterns() {
+  // Greedy grouping: each pattern joins the first existing group with a
+  // matching type and similar field usage, otherwise it seeds a new group.
+  for (auto &Pattern : DetectedPatterns) {
+    DescriptorGroup *Dest = nullptr;
+    for (auto &Group : OptimizationGroups) {
+      if (Group.SharedType != Pattern.DescType)
+        continue;
+      if (!arePatternsSimilar(Group.Patterns[0], Pattern))
+        continue;
+      Dest = &Group;
+      break;
+    }
+
+    if (Dest) {
+      Dest->Patterns.push_back(Pattern);
+    } else {
+      DescriptorGroup NewGroup;
+      NewGroup.SharedType = Pattern.DescType;
+      NewGroup.SharedBase = Pattern.BaseValue;
+      NewGroup.Patterns.push_back(Pattern);
+      OptimizationGroups.push_back(std::move(NewGroup));
+    }
+  }
+
+  // Drop groups whose estimated benefit does not clear the threshold.
+  auto BelowThreshold = [](const DescriptorGroup &Group) {
+    return Group.getOptimizationBenefit() < TDMOptBenefitThreshold;
+  };
+  OptimizationGroups.erase(std::remove_if(OptimizationGroups.begin(),
+                                          OptimizationGroups.end(),
+                                          BelowThreshold),
+                           OptimizationGroups.end());
+}
+
+bool AMDGPUTDMOptimization::arePatternsSimilar(const DescriptorPattern &A,
+                                               const DescriptorPattern &B) {
+  // Similar means: same descriptor type, and at least half of the larger
+  // constant-field set is shared between the two patterns.
+  if (A.DescType != B.DescType)
+    return false;
+
+  // Types match, so one element count covers both patterns.
+  unsigned NumElts = cast<FixedVectorType>(A.DescType)->getNumElements();
+  SmallBitVector AConstants(NumElts);
+  SmallBitVector BConstants(NumElts);
+  for (unsigned Field : A.ConstantFields)
+    AConstants.set(Field);
+  for (unsigned Field : B.ConstantFields)
+    BConstants.set(Field);
+
+  unsigned TotalConstants = std::max(AConstants.count(), BConstants.count());
+  if (TotalConstants == 0)
+    return false;
+
+  // Intersect in place to count the overlapping constant fields.
+  AConstants &= BConstants;
+  unsigned OverlapCount = AConstants.count();
+  return (float)OverlapCount / TotalConstants >= 0.5f;
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Transformation
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::transformPatterns(Function &F) {
+  // Transform every surviving group; report whether any IR was touched.
+  bool Changed = false;
+  for (auto &Group : OptimizationGroups) {
+    LLVM_DEBUG(dbgs() << "Transforming group with " << Group.Patterns.size()
+                      << " patterns, benefit = "
+                      << Group.getOptimizationBenefit() << "\n");
+    Changed |= transformDescriptorGroup(Group, F);
+  }
+  return Changed;
+}
+
+bool AMDGPUTDMOptimization::transformDescriptorGroup(DescriptorGroup &Group,
+                                                     Function &F) {
+  if (Group.Patterns.empty())
+    return false;
+
+  // Shared storage goes where the first pattern lives; if that pattern is
+  // inside a loop with a preheader, hoist the storage there so it is not
+  // re-created on every iteration. (Variable "L" avoids shadowing the
+  // Loop type name.)
+  BasicBlock *StorageLocation = Group.Patterns[0].Location;
+  if (Loop *L = Group.Patterns[0].ContainingLoop) {
+    if (BasicBlock *Preheader = L->getLoopPreheader()) {
+      StorageLocation = Preheader;
+      LLVM_DEBUG(dbgs() << "Hoisting storage outside loop\n");
+    }
+  }
+
+  // Materialize the storage at the top of the chosen block.
+  IRBuilder<> Builder(&StorageLocation->front());
+  Value *SharedStorage = createSharedStorage(Group, Builder);
+  if (!SharedStorage)
+    return false;
+
+  // Rewrite each pattern in place, inserting just before its final insert.
+  for (auto &Pattern : Group.Patterns) {
+    IRBuilder<> PatternBuilder(Pattern.Chain.back());
+    transformSinglePattern(Pattern, SharedStorage, PatternBuilder);
+  }
+
+  return true;
+}
+
+Value *AMDGPUTDMOptimization::createSharedStorage(DescriptorGroup &Group,
+                                                  IRBuilder<> &Builder) {
+  // Create alloca in address space 5 (AMDGPU private memory)
+  auto *StorageType = Group.SharedType;
+  auto *Storage = Builder.CreateAlloca(
+      StorageType, /*AddrSpace=*/5, /*ArraySize=*/nullptr, "tdm_desc_storage");
----------------
arsenm wrote:

And don't hardcode the address space; take it from the datalayout (or at least use the enum).

https://github.com/llvm/llvm-project/pull/173324


More information about the llvm-commits mailing list