[Mlir-commits] [mlir] [mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution. (PR #130240)

Charitha Saumya llvmlistbot at llvm.org
Wed Mar 12 12:59:03 PDT 2025


================
@@ -0,0 +1,647 @@
+//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-subgroup-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// HW dependent constants.
+/// TODO: These constants should be queried from the uArch interface.
+
+/// Number of work items in a subgroup.
+constexpr unsigned subgroupSize = 16;
+
+/// Low precision element types are packed so that each work item owns a group
+/// of contiguous elements covering the following number of bits.
+///
+/// Packing size in bits used by the default 2D layout (getDefaultSgMap). The
+/// DPAS A operand uses the default layout, so this also applies to it.
+constexpr unsigned packedSizeInBitsForDefault = 16;
+/// Packing size in bits required for the DPAS B operand (VNNI format).
+constexpr unsigned packedSizeInBitsForDpasB = 32;
+
+namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and data
+/// owned by each work item.
+///
+/// The copy constructor is intentionally not user-declared: declaring it (even
+/// as `= default`) suppresses the implicit move constructor and move
+/// assignment, which would turn every move of the SmallVector into a copy.
+struct Layout {
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  /// Print the layout as a bracketed, comma-separated list.
+  void print(llvm::raw_ostream &os) const;
+  /// Number of dimensions in the layout.
+  size_t size() const { return layout.size(); }
+  /// Element access; asserts that `idx` is in range.
+  int64_t operator[](size_t idx) const;
+};
+
+/// Write the dimensions as "[d0, d1, ...]" to `os`.
+void Layout::print(llvm::raw_ostream &os) const {
+  llvm::raw_ostream &out = os << "[";
+  llvm::interleaveComma(this->layout, out);
+  out << "]";
+}
+
+/// Return dimension `idx` of the layout; `idx` must be smaller than size().
+int64_t Layout::operator[](size_t idx) const {
+  assert(idx < size() && "Index out of bounds.");
+  return this->layout[idx];
+}
+
+/// Arrangement of the work items within a subgroup when it accesses a value.
+using WiLayout = Layout;
+/// Arrangement of the data owned by each individual work item.
+using WiData = Layout;
+
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For SGPropagation,
+/// the analysis state is simply the wi_layout and wi_data of each value.
+/// The purpose of this analysis is to propagate some unique layout for each
+/// value in the program, starting from some known values (like DPAS, StoreNd,
+/// etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+///  1) SGMap is a lattice with two states - assigned and not assigned.
+///  2) Two SGMap values are equal if they are both assigned or both not
+///  assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If current state is assigned, return the current state. (already
+///     a unique layout is assigned. don't change it)
+///     - Otherwise, return the other state.
+///
+/// The copy constructor is intentionally not user-declared: declaring it (even
+/// as `= default`) would suppress the implicit move operations.
+struct SGMap {
+private:
+  WiLayout wiLayout;
+  WiData wiData;
+
+public:
+  SGMap() = default;
+  SGMap(const WiLayout &layout, const WiData &data)
+      : wiLayout(layout), wiData(data) {}
+
+  /// Two lattice values are equal if they are both assigned or both
+  /// unassigned. The actual content of an assigned layout does not matter.
+  bool operator==(const SGMap &other) const {
+    return this->isAssigned() == other.isAssigned();
+  }
+
+  /// Lattice meet: keeps `lhs` if it is assigned, otherwise returns `rhs`.
+  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
+
+  /// Not used by this backward analysis; must never be reached.
+  static SGMap join(const SGMap &lhs, const SGMap &rhs);
+
+  void print(raw_ostream &os) const;
+
+  /// True when both wi_layout and wi_data are non-empty.
+  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
+
+  /// Permute the dimensions of wi_layout and wi_data by `permutation`.
+  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const WiLayout &getLayout() const { return wiLayout; }
+  const WiData &getData() const { return wiData; }
+};
+
+/// Print "wi_layout: [...], wi_data: [...]" or a not-assigned marker.
+void SGMap::print(raw_ostream &os) const {
+  if (!isAssigned()) {
+    os << "Not assigned.";
+    return;
+  }
+  os << "wi_layout: ";
+  wiLayout.print(os);
+  os << ", wi_data: ";
+  wiData.print(os);
+}
+
+/// An already-assigned `lhs` wins; an unassigned `lhs` adopts `rhs`.
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+  return lhs.isAssigned() ? lhs : rhs;
+}
+
+/// Join belongs to forward analyses; SGMapPropagation runs backward and must
+/// never call it.
+SGMap SGMap::join(const SGMap & /*lhs*/, const SGMap & /*rhs*/) {
+  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
+}
+
+/// Build the layout seen after a transpose: dimension `i` of the result takes
+/// dimension `permutation[i]` of this SGMap, for both wi_layout and wi_data.
+/// An unassigned SGMap stays unassigned.
+SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned())
+    return {};
+  WiLayout permutedLayout;
+  WiData permutedData;
+  for (int64_t srcDim : permutation) {
+    permutedLayout.layout.push_back(wiLayout.layout[srcDim]);
+    permutedData.layout.push_back(wiData.layout[srcDim]);
+  }
+  return SGMap(permutedLayout, permutedData);
+}
+
+///===----------------------------------------------------------------------===///
+/// SGMapLattice
+///===----------------------------------------------------------------------===///
+
+/// Sparse lattice element holding the SGMap of a value.
+struct SGMapLattice : public Lattice<SGMap> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
+  using Lattice<SGMap>::Lattice;
+};
+
+/// Helper functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS). It is the natural layout in which work items are arranged in a
+/// subgroup.
+
+/// Default layout for uniform values such as constants:
+///   rank 1: wi_layout [subgroupSize],    wi_data [1]
+///   rank 2: wi_layout [1, subgroupSize], wi_data [1, 1]
+static SGMap getDefaultSgMap(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank != 1)
+    return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+  return SGMap(WiLayout({subgroupSize}), WiData({1}));
+}
+
+/// Helper to get the default layout for a vector type.
+/// - 1D vectors get the rank-1 default layout.
+/// - 2D vectors get wi_layout [1, subgroupSize] and wi_data [1, k], where
+///   k = packedSizeInBitsForDefault / bitwidth for element types narrower than
+///   packedSizeInBitsForDefault bits, and k = 1 otherwise.
+static SGMap getDefaultSgMap(VectorType vectorTy) {
+  /// Expecting a 1D or 2D vector.
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
+  /// Expecting int or float element type.
+  assert(vectorTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+  /// If the rank is 1, then return default layout for 1D vector.
+  if (vectorTy.getRank() == 1)
+    return getDefaultSgMap(1);
+  /// Packing factor is determined by the element type bitwidth. Zero-width
+  /// element types (e.g. i0 also satisfy isIntOrFloat) are left unpacked so
+  /// the division below can never divide by zero.
+  int packingFactor = 1;
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (bitwidth != 0 && bitwidth < packedSizeInBitsForDefault)
+    packingFactor = packedSizeInBitsForDefault / bitwidth;
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+}
+
+/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
+/// set according to the following criteria:
+/// * For the B operand (operandNum == 1) with elements narrower than
+///   `packedSizeInBitsForDpasB` bits, data is packed in VNNI format:
+///   wi_layout [1, subgroupSize], wi_data [packedSizeInBitsForDpasB / bitwidth,
+///   1].
+/// * Every other operand (including A) uses the default layout for its vector
+///   type, i.e. data is packed according to `packedSizeInBitsForDefault`.
+static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
+  Type elementTy = vectorTy.getElementType();
+  assert(elementTy.isIntOrFloat() &&
+         "Expected int or float type in DPAS operands");
+  /// For B operand, data must be packed in minimum `packedSizeInBitsForDpasB`
+  /// and must have the VNNI format. Zero-width element types are excluded so
+  /// the packing division can never divide by zero.
+  unsigned bitwidth = elementTy.getIntOrFloatBitWidth();
+  if (operandNum == 1 && bitwidth != 0 &&
+      bitwidth < packedSizeInBitsForDpasB) {
+    WiLayout layout({1, subgroupSize});
+    WiData data({packedSizeInBitsForDpasB / bitwidth, 1});
+    return SGMap(layout, data);
+  }
+  /// Otherwise, return the default layout for the vector type.
+  return getDefaultSgMap(vectorTy);
+}
+
+///===----------------------------------------------------------------------===///
+/// SGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Backward data flow analysis to propagate the wi_layout and wi_data of each
+/// value in the program. Currently, the layouts of the operands of DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
+/// of this analysis is to propagate those known layouts to all their producers
+/// and (other) consumers.
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
+                   ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
+                      ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
+                     ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<SGMapLattice *> operands,
+                        ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<SGMapLattice *> operands,
+                            ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<SGMapLattice *> operands,
+                             ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results);
+
+public:
+  /// A single constructor is enough: an additional inheriting-constructor
+  /// `using` declaration would be hidden by this one (same signature).
+  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+  /// Transfer function: propagate layouts from the results of `op` to its
+  /// operands.
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
+                               ArrayRef<const SGMapLattice *> results) override;
+
+  /// Branch operands, call operands and external calls are intentionally
+  /// no-ops: no layout is propagated through them.
+  void visitBranchOperand(OpOperand &operand) override {}
+
+  void visitCallOperand(OpOperand &operand) override {}
+
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results) override {}
+
+  /// Values at the analysis exit get no layout: meeting with an unassigned
+  /// SGMap keeps the lattice's current state.
+  void setToExitState(SGMapLattice *lattice) override {
+    (void)lattice->meet(SGMap());
+  }
+};
+} // namespace
+
+/// Dispatch `op` to its dedicated visitor. Ops without a dedicated visitor
+/// forward every assigned result layout to all of their operands.
+/// CreateNdDescOp returns early without propagating anything (and without
+/// registering result dependencies).
+LogicalResult
+SGMapPropagation::visitOperation(Operation *op,
+                                 ArrayRef<SGMapLattice *> operands,
+                                 ArrayRef<const SGMapLattice *> results) {
+  if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
+    visitDpasOp(dpas, operands, results);
+  else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+    visitStoreNdOp(store, operands, results);
+  else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+    visitLoadNdOp(load, operands, results);
+  else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
+    visitTransposeOp(transpose, operands, results);
+  else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
+    visitVectorBitcastOp(bitcast, operands, results);
+  else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
+    visitLoadGatherOp(loadGather, operands, results);
+  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+    visitCreateDescOp(createDesc, operands, results);
+  else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
+    visitStoreScatterOp(storeScatter, operands, results);
+  else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+    visitUpdateNdOffsetOp(updateNdOffset, operands, results);
+  else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
+    visitVectorMultiReductionOp(reduction, operands, results);
+  /// No need to propagate the layout to operands in CreateNdDescOp because they
+  /// are scalars (offsets, sizes, etc.). Only the op kind matters here, so
+  /// `isa` is used instead of binding an unused cast result.
+  else if (isa<xegpu::CreateNdDescOp>(op))
+    return success();
+  /// All other ops: propagate each assigned result layout to every operand.
+  else {
+    for (const SGMapLattice *r : results) {
+      /// Unassigned result layouts carry no information to propagate.
+      if (!r->getValue().isAssigned())
+        continue;
+      for (SGMapLattice *operand : operands)
+        meet(operand, *r);
+    }
+  }
+  /// Add a dependency from each result to the program point after the
+  /// operation.
+  /// NOTE: not sure if this is required, but all other similar analyses do
+  /// this.
+  for (const SGMapLattice *r : results) {
+    addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
+  }
+  return success();
+}
+
+void SGMapPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The layout of the result must be present.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  /// We only consider 2D -> 1D reductions at this point.
+  assert(resultLayout.getLayout().size() == 1 &&
----------------
charithaintc wrote:

Good point. If it is a scalar, it is up to the consumer of the scalar to decide the layout for the result of the multi reduction.

The condition here is that, if the result layout is set (by someone), it needs to be a 1D layout. Even if the result is a scalar, I don't expect the layout to be 0-ranked, so this check is there as a sanity check. 

Overall, our current use cases don't require reducing to a scalar at the SG level. 

https://github.com/llvm/llvm-project/pull/130240


More information about the Mlir-commits mailing list