[llvm] [AMDGPU] Add TDM Descriptor Optimization Pass (PR #173324)

Juan Manuel Martinez Caamaño via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 23 01:03:22 PST 2025


================
@@ -0,0 +1,515 @@
+//===-- AMDGPUTDMOptimization.cpp - TDM Descriptor Optimization ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass optimizes Tensor Data Movement (TDM) descriptor creation patterns.
+// It identifies insertelement chains that create descriptors and transforms
+// them to use alloca+field updates, which SROA later optimizes to
+// INSERT_SUBREG.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "AMDGPUSubtarget.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-tdm-optimization"
+
+// Command-line threshold (default 10) on the optimization benefit used by the
+// TDM descriptor optimization. NOTE(review): not referenced in the code
+// visible here; presumably compared against
+// DescriptorGroup::getOptimizationBenefit() when selecting groups — confirm.
+static cl::opt<unsigned>
+    TDMOptBenefitThreshold("amdgpu-tdm-opt-threshold", cl::Hidden, cl::init(10),
+                           cl::desc("Minimum optimization benefit threshold "
+                                    "for TDM descriptor optimization"));
+
+// Declaration of the pass-registry initialization function; the pass
+// constructor calls it with the global PassRegistry.
+namespace llvm {
+void initializeAMDGPUTDMOptimizationPass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection Data Structures
+//===----------------------------------------------------------------------===//
+
+/// Represents a single descriptor creation pattern: a chain of insertelement
+/// instructions that builds a <4 x i32> or <8 x i32> value.
+struct DescriptorPattern {
+  Type *DescType = nullptr;   ///< <4 x i32> or <8 x i32>
+  Value *BaseValue = nullptr; ///< Base template (constant or computed)
+  /// Chain of insertelement instructions, in forward (program) order.
+  SmallVector<InsertElementInst *, 8> Chain;
+  SmallVector<unsigned, 8> VariableFields; ///< Fields written by the chain
+  SmallVector<unsigned, 8> ConstantFields; ///< Fields the chain leaves as-is
+  BasicBlock *Location = nullptr;          ///< Block of the final insert
+  Loop *ContainingLoop = nullptr; ///< Loop containing Location (if any)
+
+  /// Fraction of descriptor fields the chain does not overwrite
+  /// (constant fields / total fields). Requires DescType to be set.
+  float getFieldReuseRatio() const {
+    unsigned TotalFields = cast<FixedVectorType>(DescType)->getNumElements();
+    return static_cast<float>(ConstantFields.size()) / TotalFields;
+  }
+
+  /// Check if this pattern is worth optimizing.
+  bool isWorthOptimizing() const {
+    float Ratio = getFieldReuseRatio();
+
+    // Always optimize if in loop with reuse potential.
+    if (ContainingLoop && Ratio >= 0.5f)
+      return true;
+
+    // Optimize if significant field reuse.
+    if (Ratio >= 0.75f)
+      return true;
+
+    // Optimize address descriptors (common case) that reuse at least one
+    // field.
+    return isAddressDescriptor() && !ConstantFields.empty();
+  }
+
+  /// Check if this is an address descriptor (<4 x i32>).
+  bool isAddressDescriptor() const { return isI32VectorOfSize(4); }
+
+  /// Check if this is a tensor descriptor (<8 x i32>).
+  bool isTensorDescriptor() const { return isI32VectorOfSize(8); }
+
+  /// True if DescType is a fixed vector of exactly \p NumElts i32 elements.
+  /// Requires DescType to be a FixedVectorType.
+  bool isI32VectorOfSize(unsigned NumElts) const {
+    auto *VecTy = cast<FixedVectorType>(DescType);
+    return VecTy->getNumElements() == NumElts &&
+           VecTy->getElementType()->isIntegerTy(32);
+  }
+};
+
+/// Groups similar descriptor patterns for optimization
+struct DescriptorGroup {
+  SmallVector<DescriptorPattern, 4> Patterns;
+  Type *SharedType = nullptr;
+  Value *SharedBase = nullptr; ///< Common base value (if any)
+  SmallVector<unsigned, 8> SharedConstantFields;
+
+  /// Heuristic benefit of transforming this group. Each pattern contributes
+  /// 2 per constant (reused) field. That contribution is multiplied by 5 when
+  /// the pattern is inside a loop.
+  unsigned getOptimizationBenefit() const {
+    unsigned Benefit = 0;
+    for (const DescriptorPattern &Pattern : Patterns) {
+      // Base benefit from field reuse.
+      unsigned PatternBenefit = Pattern.ConstantFields.size() * 2;
+
+      // Extra benefit for loop patterns. Scale only this pattern's share.
+      // Multiplying the running total would make the result depend on
+      // pattern order and grow geometrically with the number of loop
+      // patterns.
+      if (Pattern.ContainingLoop)
+        PatternBenefit *= 5;
+
+      Benefit += PatternBenefit;
+    }
+    return Benefit;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AMDGPUTDMOptimization Pass
+//===----------------------------------------------------------------------===//
+
+/// Function pass that detects TDM descriptor insertelement chains and
+/// rewrites the ones judged worthwhile (see the file header).
+class AMDGPUTDMOptimization : public FunctionPass {
+public:
+  static char ID;
+
+  AMDGPUTDMOptimization() : FunctionPass(ID) {
+    initializeAMDGPUTDMOptimizationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  /// Loop analysis of the function being processed (set by runOnFunction).
+  LoopInfo *LI = nullptr;
+
+  /// Descriptor patterns detected in the current function.
+  SmallVector<DescriptorPattern, 16> DetectedPatterns;
+
+  /// Groups of optimizable patterns built from DetectedPatterns.
+  SmallVector<DescriptorGroup, 8> OptimizationGroups;
+
+  // Phases driven by runOnFunction, in order: detect, group, transform.
+  bool detectDescriptorPatterns(Function &F);
+  void groupSimilarPatterns();
+  bool transformPatterns(Function &F);
+
+  // Detection helpers.
+  bool isDescriptorType(Type *Ty) const;
+  DescriptorPattern analyzeInsertChain(InsertElementInst *FinalInsert);
+  Value *extractBaseValue(const DescriptorPattern &Pattern);
+
+  // Transformation helpers.
+  bool transformDescriptorGroup(DescriptorGroup &Group, Function &F);
+  Value *createSharedStorage(DescriptorGroup &Group, IRBuilder<> &Builder);
+  void transformSinglePattern(DescriptorPattern &Pattern, Value *SharedStorage,
+                              IRBuilder<> &Builder);
+
+  // Miscellaneous utilities.
+  Loop *getContainingLoop(BasicBlock *BB);
+  bool arePatternsSimilar(const DescriptorPattern &A,
+                          const DescriptorPattern &B);
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+/// Runs detection, grouping and transformation on \p F.
+/// Returns true if transformPatterns reports that the IR was changed.
+bool AMDGPUTDMOptimization::runOnFunction(Function &F) {
+  // Respect optnone and -opt-bisect-limit like other optimization passes.
+  if (skipFunction(F))
+    return false;
+
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+  // Per-function state must never carry over from a previous function.
+  DetectedPatterns.clear();
+  OptimizationGroups.clear();
+
+  LLVM_DEBUG(dbgs() << "Running TDM optimization on function: " << F.getName()
+                    << "\n");
+
+  // Phase 1: Detect descriptor patterns
+  if (!detectDescriptorPatterns(F)) {
+    LLVM_DEBUG(dbgs() << "No descriptor patterns found\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "Found " << DetectedPatterns.size()
+                    << " descriptor patterns\n");
+
+  // Phase 2: Group similar patterns for optimization
+  groupSimilarPatterns();
+
+  LLVM_DEBUG(dbgs() << "Created " << OptimizationGroups.size()
+                    << " optimization groups\n");
+
+  // Phase 3: Transform patterns
+  bool Changed = transformPatterns(F);
+
+  // Release per-function state now instead of holding it until the next run.
+  DetectedPatterns.clear();
+  OptimizationGroups.clear();
+
+  return Changed;
+}
+
+void AMDGPUTDMOptimization::getAnalysisUsage(AnalysisUsage &AU) const {
+  // Declare that the CFG is preserved (no blocks or edges added or removed).
+  AU.setPreservesCFG();
+  // Loop membership feeds the profitability heuristics (ContainingLoop).
+  AU.addRequired<LoopInfoWrapperPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection
+//===----------------------------------------------------------------------===//
+
+/// Scans \p F for the final insertelement of each descriptor-building chain.
+/// Every chain judged worth optimizing is appended to DetectedPatterns.
+/// Returns true if at least one pattern was appended.
+bool AMDGPUTDMOptimization::detectDescriptorPatterns(Function &F) {
+  const auto NumBefore = DetectedPatterns.size();
+
+  for (BasicBlock &Block : F) {
+    for (Instruction &Inst : Block) {
+      auto *Tail = dyn_cast<InsertElementInst>(&Inst);
+      if (!Tail || !isDescriptorType(Tail->getType()))
+        continue;
+
+      // Only a chain's last insert qualifies: it has a single user, and that
+      // user is not another insertelement.
+      if (!Tail->hasOneUse())
+        continue;
+      if (isa<InsertElementInst>(*Tail->user_begin()))
+        continue;
+
+      DescriptorPattern Candidate = analyzeInsertChain(Tail);
+      if (Candidate.Chain.empty())
+        continue;
+
+      if (Candidate.isWorthOptimizing()) {
+        LLVM_DEBUG(
+            dbgs() << "Found optimizable pattern: "
+                   << (Candidate.isAddressDescriptor() ? "Address" : "Tensor")
+                   << " descriptor with " << Candidate.ConstantFields.size()
+                   << " constant fields\n");
+        DetectedPatterns.push_back(std::move(Candidate));
+      } else {
+        LLVM_DEBUG(
+            dbgs() << "Pattern not worth optimizing: field reuse ratio = "
+                   << Candidate.getFieldReuseRatio() << "\n");
+      }
+    }
+  }
+
+  return DetectedPatterns.size() > NumBefore;
+}
+
+/// True for <4 x i32> (address) and <8 x i32> (tensor) fixed vectors.
+bool AMDGPUTDMOptimization::isDescriptorType(Type *Ty) const {
+  auto *VecTy = dyn_cast<FixedVectorType>(Ty);
+  if (!VecTy)
+    return false;
+  if (!VecTy->getElementType()->isIntegerTy(32))
+    return false;
+
+  switch (VecTy->getNumElements()) {
+  case 4: // Address descriptor
+  case 8: // Tensor descriptor
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Builds a DescriptorPattern for the chain ending at \p FinalInsert.
+/// The chain is found by following operand 0 backwards through
+/// insertelements. Fields written by the chain become VariableFields, each
+/// listed once. All other fields become ConstantFields.
+/// If any insert uses a non-constant or out-of-range index, the written
+/// field is unknown. In that case the returned pattern has an empty Chain,
+/// which callers treat as "no pattern".
+DescriptorPattern
+AMDGPUTDMOptimization::analyzeInsertChain(InsertElementInst *FinalInsert) {
+  DescriptorPattern Pattern;
+  Pattern.DescType = FinalInsert->getType();
+  Pattern.Location = FinalInsert->getParent();
+  Pattern.ContainingLoop = getContainingLoop(Pattern.Location);
+
+  // Trace back the insertelement chain through the vector operand.
+  // NOTE(review): intermediate inserts are not required to have a single
+  // use, so two chains may share a prefix.
+  Value *CurrentVal = FinalInsert;
+  while (auto *IE = dyn_cast<InsertElementInst>(CurrentVal)) {
+    Pattern.Chain.push_back(IE);
+    CurrentVal = IE->getOperand(0); // Vector being inserted into
+  }
+
+  // Reverse to get forward order
+  std::reverse(Pattern.Chain.begin(), Pattern.Chain.end());
+
+  // Extract base value (the initial vector)
+  Pattern.BaseValue = extractBaseValue(Pattern);
+
+  // Analyze which fields are written (variable) vs untouched (constant).
+  unsigned NumElements =
+      cast<FixedVectorType>(Pattern.DescType)->getNumElements();
+  SmallBitVector WrittenFields(NumElements, false);
+
+  for (InsertElementInst *IE : Pattern.Chain) {
+    auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
+    // With an unknown or out-of-range index, we cannot tell which field is
+    // overwritten. Counting the remaining fields as constant would overstate
+    // reuse, so reject the whole chain.
+    if (!CI || CI->getValue().uge(NumElements)) {
+      Pattern.Chain.clear();
+      Pattern.BaseValue = nullptr;
+      return Pattern;
+    }
+
+    unsigned Idx = CI->getZExtValue();
+    // A field may be inserted more than once; record it only once.
+    if (!WrittenFields.test(Idx)) {
+      WrittenFields.set(Idx);
+      Pattern.VariableFields.push_back(Idx);
+    }
+  }
+
+  // Fields not written by the chain are constant
+  for (unsigned I = 0; I != NumElements; ++I)
+    if (!WrittenFields.test(I))
+      Pattern.ConstantFields.push_back(I);
+
+  return Pattern;
+}
+
+/// Returns the vector that the first insert of \p Pattern's chain writes
+/// into, or nullptr when the chain is empty. Constant templates,
+/// shufflevector results and any other computed vectors are all returned
+/// as-is. Tracing further back is not done yet.
+Value *
+AMDGPUTDMOptimization::extractBaseValue(const DescriptorPattern &Pattern) {
+  if (Pattern.Chain.empty())
+    return nullptr;
+  InsertElementInst *FirstInsert = Pattern.Chain.front();
+  return FirstInsert->getOperand(0);
+}
+
+/// Returns the loop LoopInfo reports for \p BB, or nullptr without LoopInfo.
+Loop *AMDGPUTDMOptimization::getContainingLoop(BasicBlock *BB) {
+  if (!LI)
+    return nullptr;
+  return LI->getLoopFor(BB);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Grouping
+//===----------------------------------------------------------------------===//
----------------
jmmartinez wrote:

This is a very subjective take, but I find using comments to structure code problematic.

I find that carefully choosing function names works better (which you actually did).

Somebody will eventually add code below that is not related to "pattern grouping", and then the comment will only add to the confusion.

https://github.com/llvm/llvm-project/pull/173324


More information about the llvm-commits mailing list