[llvm] [AMDGPU] Add TDM Descriptor Optimization Pass (PR #173324)

Quentin Colombet via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 22 17:56:57 PST 2025


================
@@ -0,0 +1,495 @@
+//===-- AMDGPUTDMOptimization.cpp - TDM Descriptor Optimization ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass optimizes Tensor Data Movement (TDM) descriptor creation patterns.
+// It identifies insertelement chains that create descriptors and transforms them
+// to use alloca+field updates, which SROA later optimizes to INSERT_SUBREG.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "AMDGPUSubtarget.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-tdm-optimization"
+
+// Hidden command-line knob (-amdgpu-tdm-opt-threshold, default 10): the
+// minimum optimization benefit required for TDM descriptor optimization.
+static cl::opt<unsigned> TDMOptBenefitThreshold(
+    "amdgpu-tdm-opt-threshold", cl::Hidden, cl::init(10),
+    cl::desc("Minimum optimization benefit threshold for TDM descriptor optimization"));
+
+namespace llvm {
+// Forward declaration of the pass initializer; it is called from the
+// AMDGPUTDMOptimization constructor.
+void initializeAMDGPUTDMOptimizationPass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection Data Structures
+//===----------------------------------------------------------------------===//
+
+/// Represents a single descriptor creation pattern: a chain of insertelement
+/// instructions that builds a descriptor vector.
+///
+/// All pointer members default to nullptr so that a default-constructed
+/// pattern never carries indeterminate pointers. DescType must be set to a
+/// FixedVectorType before any of the query methods below are called.
+struct DescriptorPattern {
+  Type *DescType = nullptr;                  ///< <4 x i32> or <8 x i32>
+  Value *BaseValue = nullptr;                ///< Base template (constant or computed)
+  SmallVector<InsertElementInst *, 8> Chain; ///< Chain of insertelement instructions
+  SmallVector<unsigned, 8> VariableFields;   ///< Fields that change
+  SmallVector<unsigned, 8> ConstantFields;   ///< Fields that stay constant
+  BasicBlock *Location = nullptr;            ///< Where the pattern is located
+  Loop *ContainingLoop = nullptr;            ///< Loop containing this pattern (if any)
+
+  /// Calculate field reuse ratio (constant fields / total fields).
+  /// \returns ConstantFields.size() divided by the element count of DescType.
+  float getFieldReuseRatio() const {
+    unsigned TotalFields = cast<FixedVectorType>(DescType)->getNumElements();
+    return static_cast<float>(ConstantFields.size()) / TotalFields;
+  }
+
+  /// Check if this pattern is worth optimizing.
+  /// \returns true if any of the heuristics below accepts the pattern.
+  bool isWorthOptimizing() const {
+    // Compute the ratio once; every heuristic below is based on it or on the
+    // raw constant-field count.
+    const float ReuseRatio = getFieldReuseRatio();
+
+    // Always optimize if in loop with reuse potential.
+    if (ContainingLoop && ReuseRatio >= 0.5f)
+      return true;
+
+    // Optimize if significant field reuse.
+    if (ReuseRatio >= 0.75f)
+      return true;
+
+    // Optimize address descriptors (common case) as soon as at least one
+    // field is constant.
+    return isAddressDescriptor() && !ConstantFields.empty();
+  }
+
+  /// Check if this is an address descriptor (<4 x i32>).
+  bool isAddressDescriptor() const { return isI32VectorOf(4); }
+
+  /// Check if this is a tensor descriptor (<8 x i32>).
+  bool isTensorDescriptor() const { return isI32VectorOf(8); }
+
+private:
+  /// Shared implementation of the two descriptor-kind queries.
+  /// \returns true if DescType is a fixed vector of exactly \p NumElts i32
+  /// elements.
+  bool isI32VectorOf(unsigned NumElts) const {
+    auto *VecTy = cast<FixedVectorType>(DescType);
+    return VecTy->getNumElements() == NumElts &&
+           VecTy->getElementType()->isIntegerTy(32);
+  }
+};
+
+/// Groups similar descriptor patterns for optimization.
+struct DescriptorGroup {
+  SmallVector<DescriptorPattern, 4> Patterns;
+  Type *SharedType = nullptr;
+  Value *SharedBase = nullptr; ///< Common base value (if any)
+  SmallVector<unsigned, 8> SharedConstantFields;
+
+  /// Calculate total optimization benefit.
+  ///
+  /// Each pattern contributes two points per constant field, and that
+  /// contribution is multiplied by five when the pattern sits inside a loop.
+  /// The loop bonus is applied to the pattern's own contribution only, so the
+  /// result does not depend on the order of Patterns and does not compound
+  /// across several loop patterns.
+  ///
+  /// \returns the sum of the per-pattern benefits (0 for an empty group).
+  unsigned getOptimizationBenefit() const {
+    unsigned Benefit = 0;
+    for (const DescriptorPattern &Pattern : Patterns) {
+      // Base benefit from field reuse.
+      unsigned PatternBenefit = Pattern.ConstantFields.size() * 2;
+
+      // Extra benefit for loop patterns.
+      if (Pattern.ContainingLoop)
+        PatternBenefit *= 5;
+
+      Benefit += PatternBenefit;
+    }
+    return Benefit;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AMDGPUTDMOptimization Pass
+//===----------------------------------------------------------------------===//
+
+class AMDGPUTDMOptimization : public FunctionPass {
+  /// Loop information for the function being processed; assigned at the
+  /// start of runOnFunction().
+  LoopInfo *LI = nullptr;
+
+  /// Descriptor patterns detected in the current function.
+  SmallVector<DescriptorPattern, 16> DetectedPatterns;
+
+  /// Groups of patterns selected for optimization.
+  SmallVector<DescriptorGroup, 8> OptimizationGroups;
+
+public:
+  static char ID;
+
+  /// Construct the pass and register it with the global pass registry.
+  AMDGPUTDMOptimization() : FunctionPass(ID) {
+    initializeAMDGPUTDMOptimizationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  // Main optimization phases, invoked in this order by runOnFunction().
+  bool detectDescriptorPatterns(Function &F);
+  void groupSimilarPatterns();
+  bool transformPatterns(Function &F);
+
+  // Pattern detection helpers.
+  bool isDescriptorType(Type *Ty) const;
+  DescriptorPattern analyzeInsertChain(InsertElementInst *FinalInsert);
+  Value *extractBaseValue(const DescriptorPattern &Pattern);
+
+  // Transformation helpers.
+  bool transformDescriptorGroup(DescriptorGroup &Group, Function &F);
+  Value *createSharedStorage(DescriptorGroup &Group, IRBuilder<> &Builder);
+  void transformSinglePattern(DescriptorPattern &Pattern, Value *SharedStorage,
+                              IRBuilder<> &Builder);
+
+  // Utility functions.
+  Loop *getContainingLoop(BasicBlock *BB);
+  bool arePatternsSimilar(const DescriptorPattern &A,
+                          const DescriptorPattern &B);
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+/// Run the pass on \p F: detect descriptor patterns, group them, and
+/// transform the groups.
+/// \returns the result of transformPatterns(), or false when no pattern was
+/// detected.
+bool AMDGPUTDMOptimization::runOnFunction(Function &F) {
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+  LLVM_DEBUG(dbgs() << "Running TDM optimization on function: "
+                    << F.getName() << "\n");
+
+  bool MadeChange = false;
+
+  // Phase 1: Detect descriptor patterns.
+  if (detectDescriptorPatterns(F)) {
+    LLVM_DEBUG(dbgs() << "Found " << DetectedPatterns.size()
+                      << " descriptor patterns\n");
+
+    // Phase 2: Group similar patterns for optimization.
+    groupSimilarPatterns();
+    LLVM_DEBUG(dbgs() << "Created " << OptimizationGroups.size()
+                      << " optimization groups\n");
+
+    // Phase 3: Transform patterns.
+    MadeChange = transformPatterns(F);
+
+    // Drop per-function state so the next function starts clean.
+    DetectedPatterns.clear();
+    OptimizationGroups.clear();
+  } else {
+    LLVM_DEBUG(dbgs() << "No descriptor patterns found\n");
+  }
+
+  return MadeChange;
+}
+
+/// Declare the pass's analysis contract: it leaves the CFG intact and needs
+/// LoopInfoWrapperPass, which runOnFunction() queries.
+void AMDGPUTDMOptimization::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesCFG();
+  AU.addRequired<LoopInfoWrapperPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::detectDescriptorPatterns(Function &F) {
+  bool FoundPatterns = false;
+
+  // Scan function for insertelement instructions that create descriptors
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      auto *IE = dyn_cast<InsertElementInst>(&I);
+      if (!IE || !isDescriptorType(IE->getType()))
+        continue;
+
+      // Check if this is the final insert in a descriptor creation chain
+      if (!IE->hasOneUse() || isa<InsertElementInst>(*IE->user_begin()))
+        continue;
----------------
qcolombet wrote:

Is the one-use limitation consistent with what the frontend currently produces?

https://github.com/llvm/llvm-project/pull/173324


More information about the llvm-commits mailing list