[Mlir-commits] [mlir] [mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution. (PR #130240)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Mar 13 02:24:26 PDT 2025


================
@@ -0,0 +1,651 @@
+//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// HW dependent constants.
+/// TODO: These constants should be queried from the uArch interface.
+
+/// Number of work items in a subgroup.
+constexpr unsigned subgroupSize = 16;
+/// If DPAS A or B operands have low precision element types they must be
+/// packed according to the following sizes (in bits).
+/// Minimum packing size per register for DPAS A; also used when building the
+/// default 2D layout.
+constexpr unsigned packedSizeInBitsForDefault = 16;
+/// Minimum packing size per register for DPAS B.
+constexpr unsigned packedSizeInBitsForDpasB = 32;
+
+namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and data
+/// owned by each work item.
+///
+/// The only member is a SmallVector, so the class follows the rule of zero and
+/// relies on the implicitly generated copy and move operations. (Declaring a
+/// defaulted copy constructor would suppress the implicit move operations and
+/// make the implicit copy assignment deprecated.)
+struct Layout {
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  /// Print the layout as a bracketed, comma-separated list, e.g. `[1, 16]`.
+  void print(llvm::raw_ostream &os) const;
+  /// Number of dimensions in the layout.
+  size_t size() const { return layout.size(); }
+  /// Element access; asserts that `idx` is in bounds.
+  int64_t operator[](size_t idx) const;
+};
+
+/// Print the layout as `[d0, d1, ...]`.
+void Layout::print(llvm::raw_ostream &os) const {
+  llvm::interleaveComma(layout, os << "[");
+  os << "]";
+}
+
+/// Read dimension `idx`; out-of-range indices trip the assertion.
+int64_t Layout::operator[](size_t idx) const {
+  const size_t numDims = layout.size();
+  assert(idx < numDims && "Index out of bounds.");
+  return layout[idx];
+}
+
+/// Both aliases share the Layout representation:
+///  - WiLayout: how the work items of a subgroup are arranged when they access
+///    some value.
+///  - WiData: the layout of the data owned by each work item.
+using WiLayout = Layout;
+using WiData = Layout;
+
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For SGPropagation,
+/// the analysis state is simply the wi_layout and wi_data of each value.
+/// The purpose of this analysis is to propagate some unique layout for each
+/// value in the program starting from some known values (like DPAS, StoreNd,
+/// etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+///  1) SGMap is a lattice with two states - assigned and not assigned.
+///  2) Two SGMap values are equal if they are both assigned or both not
+///  assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If current state is assigned, return the current state. (A unique
+///     layout is already assigned; don't change it.)
+///     - Otherwise, return the other state.
+///
+/// SGMap holds only value-semantic members, so it follows the rule of zero and
+/// relies on the implicitly generated copy and move operations.
+struct SGMap {
+private:
+  WiLayout wiLayout;
+  WiData wiData;
+
+public:
+  SGMap() = default;
+  SGMap(const WiLayout &layout, const WiData &data)
+      : wiLayout(layout), wiData(data) {}
+
+  /// Two lattice values are equal if both or neither have a layout assigned.
+  /// The actual content of the layout does not matter.
+  bool operator==(const SGMap &other) const {
+    return this->isAssigned() == other.isAssigned();
+  }
+
+  /// Lattice meet: keep `lhs` if it is assigned, otherwise take `rhs`.
+  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
+
+  /// Lattice join: not used by the backward analysis.
+  static SGMap join(const SGMap &lhs, const SGMap &rhs);
+
+  /// Print wi_layout and wi_data, or a "not assigned" marker.
+  void print(raw_ostream &os) const;
+
+  /// A value is assigned once both wi_layout and wi_data are non-empty.
+  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
+
+  /// Return a copy with wi_layout and wi_data permuted by `permutation`.
+  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const WiLayout &getLayout() const { return wiLayout; }
+  const WiData &getData() const { return wiData; }
+};
+
+/// Print `wi_layout: [...], wi_data: [...]`, or "Not assigned." when the map
+/// carries no layout yet.
+void SGMap::print(raw_ostream &os) const {
+  if (!isAssigned()) {
+    os << "Not assigned.";
+    return;
+  }
+  os << "wi_layout: ";
+  wiLayout.print(os);
+  os << ", wi_data: ";
+  wiData.print(os);
+}
+
+/// Keep the first assigned state: `lhs` wins whenever it already carries a
+/// layout; otherwise `rhs` is returned unchanged.
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+  return lhs.isAssigned() ? lhs : rhs;
+}
+
+/// Never expected to run: SGMapPropagation is a backward analysis and combines
+/// states through meet only.
+SGMap SGMap::join(const SGMap & /*lhs*/, const SGMap & /*rhs*/) {
+  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
+}
+
+/// Get the transposed layout according to the given permutation: dimension `i`
+/// of the returned map takes wi_layout/wi_data from dimension `permutation[i]`
+/// of this map. An unassigned map transposes to an unassigned map.
+/// The permutation must have exactly one entry per layout dimension; a
+/// mismatch (e.g. a 2D permutation applied to a 1D layout) would otherwise
+/// read past the end of the layout.
+SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned())
+    return {};
+  assert(permutation.size() == wiLayout.size() &&
+         permutation.size() == wiData.size() &&
+         "Permutation rank must match the layout rank.");
+  WiLayout newLayout;
+  WiData newData;
+  newLayout.layout.reserve(permutation.size());
+  newData.layout.reserve(permutation.size());
+  for (int64_t idx : permutation) {
+    newLayout.layout.push_back(wiLayout.layout[idx]);
+    newData.layout.push_back(wiData.layout[idx]);
+  }
+  return SGMap(newLayout, newData);
+}
+
+///===----------------------------------------------------------------------===///
+/// SGMapLattice
+///===----------------------------------------------------------------------===///
+
+/// Lattice holding the SGMap for each value.
+/// This is the state type tracked per value by SGMapPropagation; all lattice
+/// constructors are inherited from the generic `Lattice<SGMap>`.
+struct SGMapLattice : public Lattice<SGMap> {
+  // Explicit TypeID for this class, which lives in an anonymous namespace.
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
+  using Lattice::Lattice;
+};
+
+/// Helper Functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS). This is the natural layout work items are arranged in a
+/// subgroup.
+
+/// Default layout for uniform values (e.g. constants), chosen by rank alone:
+///  - rank 1: wi_layout = [subgroupSize],    wi_data = [1]
+///  - rank 2: wi_layout = [1, subgroupSize], wi_data = [1, 1]
+static SGMap getDefaultSgMap(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  const bool isOneDim = rank == 1;
+  if (!isOneDim)
+    return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+  return SGMap(WiLayout({subgroupSize}), WiData({1}));
+}
+
+/// Default layout for a 1D or 2D vector with int or float elements.
+/// A 1D vector gets the rank-1 default. A 2D vector gets wi_layout
+/// [1, subgroupSize]; if its elements are narrower than
+/// `packedSizeInBitsForDefault` bits, wi_data packs
+/// `packedSizeInBitsForDefault / bitwidth` elements along the inner dimension,
+/// otherwise wi_data is [1, 1].
+static SGMap getDefaultSgMap(VectorType vectorTy) {
+  const int64_t rank = vectorTy.getRank();
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  Type elemTy = vectorTy.getElementType();
+  assert(elemTy.isIntOrFloat() && "Expected int or float element type.");
+  if (rank == 1)
+    return getDefaultSgMap(1);
+  // The packing factor depends only on the element bitwidth.
+  const unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
+  const int packingFactor = bitwidth < packedSizeInBitsForDefault
+                                ? packedSizeInBitsForDefault / bitwidth
+                                : 1;
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+}
+
+/// Helper Function to get the expected layouts for DPAS operands.
+/// `operandNum` selects the operand: 0 = A, 1 = B, 2 = C (accumulator).
+/// `wi_data` is set according to the following criteria:
+/// * For the B operand with elements narrower than `packedSizeInBitsForDpasB`
+///   bits, the data must be packed in the VNNI format along the outer
+///   dimension: wi_layout = [1, subgroupSize] and
+///   wi_data = [packedSizeInBitsForDpasB / bitwidth, 1].
+/// * Otherwise (A, C, or a B operand with wide elements), the default layout
+///   for the vector type is used, which packs elements narrower than
+///   `packedSizeInBitsForDefault` bits along the inner dimension.
+static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
+  Type elementTy = vectorTy.getElementType();
+  assert(elementTy.isIntOrFloat() &&
+         "Expected int or float type in DPAS operands");
+  const unsigned bitwidth = elementTy.getIntOrFloatBitWidth();
+  // For the B operand, data must be packed in minimum
+  // `packedSizeInBitsForDpasB` bits and must have the VNNI format.
+  if (operandNum == 1 && bitwidth < packedSizeInBitsForDpasB) {
+    WiLayout layout({1, subgroupSize});
+    WiData data({packedSizeInBitsForDpasB / bitwidth, 1});
+    return SGMap(layout, data);
+  }
+  // Otherwise, return the default layout for the vector type.
+  return getDefaultSgMap(vectorTy);
+}
+
+///===----------------------------------------------------------------------===///
+/// SGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Backward data flow analysis to propagate the wi_layout and wi_data of each
+/// value in the program. Currently, the layouts for operands DPAS, StoreNd, and
+/// StoreScatter are fixed (known before propagation). The purpose of this
+/// analysis is to propagate those known layouts to all their producers and
+/// (other) consumers.
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
+                   ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
+                      ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
+                     ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<SGMapLattice *> operands,
+                        ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<SGMapLattice *> operands,
+                            ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<SGMapLattice *> operands,
+                             ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results);
+
+public:
+  /// Inherit the base (DataFlowSolver &, SymbolTableCollection &) constructor;
+  /// a hand-written forwarding constructor with the same signature would only
+  /// duplicate it.
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
+                               ArrayRef<const SGMapLattice *> results) override;
+
+  /// Branch operands are intentionally left untouched by this analysis.
+  void visitBranchOperand(OpOperand &operand) override {}
+
+  /// Call operands are intentionally left untouched by this analysis.
+  void visitCallOperand(OpOperand &operand) override {}
+
+  /// External calls do not propagate any layout to their operands.
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results) override {}
+
+  /// The exit state is the unassigned SGMap; meeting with it never replaces
+  /// an already assigned layout.
+  void setToExitState(SGMapLattice *lattice) override {
+    (void)lattice->meet(SGMap());
+  }
+};
+} // namespace
+
+/// Dispatch to the op-specific transfer functions. Ops without a dedicated
+/// handler propagate every assigned result layout unchanged to all operands.
+/// Always succeeds.
+LogicalResult
+SGMapPropagation::visitOperation(Operation *op,
+                                 ArrayRef<SGMapLattice *> operands,
+                                 ArrayRef<const SGMapLattice *> results) {
+  TypeSwitch<Operation *>(op)
+      .Case<xegpu::DpasOp>(
+          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
+      .Case<xegpu::StoreNdOp>(
+          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
+      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
+        visitStoreScatterOp(storeScatterOp, operands, results);
+      })
+      .Case<xegpu::LoadNdOp>(
+          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
+      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
+        visitLoadGatherOp(loadGatherOp, operands, results);
+      })
+      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
+        visitCreateDescOp(createDescOp, operands, results);
+      })
+      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
+        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+      })
+      // No need to propagate the layout to operands in CreateNdDescOp because
+      // they are scalars (offsets, sizes, etc.). The switch has no result
+      // type, so the case body is simply empty.
+      .Case<xegpu::CreateNdDescOp>([&](auto) {})
+      .Case<vector::TransposeOp>([&](auto transposeOp) {
+        visitTransposeOp(transposeOp, operands, results);
+      })
+      .Case<vector::BitCastOp>([&](auto bitcastOp) {
+        visitVectorBitcastOp(bitcastOp, operands, results);
+      })
+      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
+        visitVectorMultiReductionOp(reductionOp, operands, results);
+      })
+      // All other ops: propagate each assigned result layout to every operand.
+      .Default([&](Operation *) {
+        for (const SGMapLattice *r : results) {
+          if (!r->getValue().isAssigned())
+            continue;
+          for (SGMapLattice *operand : operands)
+            meet(operand, *r);
+        }
+      });
+  // Add a dependency from each result to the program point after the
+  // operation.
+  // NOTE(review): presumably this re-enqueues the op when a result lattice
+  // changes; all other similar analyses do this. Confirm whether the base
+  // class already does it.
+  for (const SGMapLattice *r : results) {
+    addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
+  }
+  return success();
+}
+
+/// For vector::MultiDimReductionOp (only 2D -> 1D reductions are handled), the
+/// source operand gets the default 2D layout and the accumulator operand takes
+/// the result layout.
+void SGMapPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  // Nothing to propagate until the result has a layout.
+  const SGMap resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  assert(resultLayout.getLayout().size() == 1 &&
+         "Expected 1D layout for reduction result.");
+  SGMapLattice *source = operands[0];
+  SGMapLattice *acc = operands[1];
+  // A 1D result implies a 2D source, which receives the default 2D layout.
+  propagateIfChanged(source, source->meet(getDefaultSgMap(2)));
+  // The accumulator mirrors the result layout.
+  propagateIfChanged(acc, acc->meet(resultLayout));
+}
+
+/// UpdateNdOffsetOp: once the result descriptor has a layout, copy it
+/// unchanged to the source descriptor (operand 0). The offset operands are not
+/// visited.
+void SGMapPropagation::visitUpdateNdOffsetOp(
+    xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  const SGMap resultLayout = results[0]->getValue();
+  if (resultLayout.isAssigned()) {
+    SGMapLattice *source = operands[0];
+    propagateIfChanged(source, source->meet(resultLayout));
+  }
+}
+
+/// Set the layouts for DPAS A, B, and C operands. A and B always receive their
+/// anchor layouts; the accumulator C is handled only when a third operand is
+/// present.
+void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results) {
+  auto aTy = dpas.getLhsType();
+  auto bTy = dpas.getRhsType();
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+  propagateIfChanged(operands[1],
+                     operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+  if (operands.size() > 2) {
+    auto cTy = dpas.getAccType();
+    propagateIfChanged(operands[2],
+                       operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+  }
+}
+
+/// StoreNdOp is a layout anchor: every operand (the stored value and the
+/// tensor descriptor) receives the default layout of the stored value type.
+void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
+                                      ArrayRef<SGMapLattice *> operands,
+                                      ArrayRef<const SGMapLattice *> results) {
+  const SGMap storeLayout = getDefaultSgMap(store.getValueType());
+  for (size_t i = 0, e = operands.size(); i < e; ++i)
+    propagateIfChanged(operands[i], operands[i]->meet(storeLayout));
+}
+
+/// LoadNdOp: the tensor descriptor operand takes the layout of the loaded
+/// value, permuted by the op's transpose attribute when one is present.
+void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
+                                     ArrayRef<SGMapLattice *> operands,
+                                     ArrayRef<const SGMapLattice *> results) {
+  const SGMap valueLayout = results[0]->getValue();
+  // The descriptor layout is derived from the value layout, so wait for it.
+  if (!valueLayout.isAssigned())
+    return;
+  auto transpose = load.getTranspose();
+  if (!transpose) {
+    propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+    return;
+  }
+  // The transpose effect is expected to be abstracted away before this
+  // analysis runs; warn, then honor it anyway.
+  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+                   "SGMapPropagation stage.");
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(valueLayout.getTransposedLayout(
+                         transpose.value())));
+}
+
+/// vector::TransposeOp: the source operand receives the result layout permuted
+/// by the op's permutation, once the result layout is known.
+void SGMapPropagation::visitTransposeOp(
+    vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  const SGMap resultLayout = results[0]->getValue();
+  if (resultLayout.isAssigned()) {
+    SGMapLattice *source = operands[0];
+    propagateIfChanged(source, source->meet(resultLayout.getTransposedLayout(
+                                   transpose.getPermutation())));
+  }
+}
+
+/// For vector::BitCastOp, the wi_data of the source layout is changed based on
+/// the bit width of the source and result types; wi_layout is unchanged.
+/// With `ratio` = larger bitwidth / smaller bitwidth:
+///  * widening bitcast (source elements narrower): the scaled wi_data entry is
+///    multiplied by `ratio`;
+///  * narrowing bitcast (source elements wider): the scaled wi_data entry is
+///    divided by `ratio`.
+/// For a 1D layout the single entry is scaled. For a 2D layout, dim 1 is scaled
+/// when wi_data[0] == 1 and dim 0 otherwise; the other dim is set to 1.
+/// Nothing is propagated when the layout is neither 1D nor 2D, or when a
+/// narrowing bitcast would leave a work item with a fractional number of
+/// source elements.
+void SGMapPropagation::visitVectorBitcastOp(
+    vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  // Need the layout of bitcast result to propagate to the operands.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  int64_t inElemTyBitWidth =
+      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+  int64_t outElemTyBitWidth =
+      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+
+  // WiLayout does not change.
+  const WiLayout &newWiLayout = resultLayout.getLayout();
+  const WiData &currData = resultLayout.getData();
+  const size_t rank = currData.size();
+  // Indexing currData[1] on a 1D layout would be out of bounds; only 1D and
+  // 2D layouts are handled.
+  if (rank != 1 && rank != 2)
+    return;
+
+  // Dimension whose wi_data is rescaled (see function comment).
+  const size_t scaledDim = currData[0] == 1 ? rank - 1 : 0;
+  int64_t scaled = currData[scaledDim];
+  if (inElemTyBitWidth < outElemTyBitWidth) {
+    // It's a widening bitcast.
+    scaled *= outElemTyBitWidth / inElemTyBitWidth;
+  } else {
+    // It's a narrowing (or same-width) bitcast. Plain integer division would
+    // silently truncate (possibly to 0) when wi_data is not a multiple of the
+    // ratio; such a layout cannot be expressed in source elements.
+    const int64_t ratio = inElemTyBitWidth / outElemTyBitWidth;
+    if (scaled % ratio != 0)
+      return;
+    scaled /= ratio;
+  }
+
+  WiData newWiData;
+  if (rank == 1)
+    newWiData = WiData({scaled});
+  else if (scaledDim == 1)
+    newWiData = WiData({1, scaled});
+  else
+    newWiData = WiData({scaled, 1});
+
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(SGMap(newWiLayout, newWiData)));
+}
+
+/// Propagate the layout of the result to the tensor descriptor and mask
+/// operands in LoadGatherOp.
+void SGMapPropagation::visitLoadGatherOp(
+    xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto valueLayout = results[0]->getValue();
+  /// Need the layout of the value to propagate to the tensor descriptor.
+  if (!valueLayout.isAssigned())
+    return;
+
+  SGMap tensorDescLayout;
+  if (load.getTranspose()) {
+    /// LoadGatherOp has the transpose effect. However, at the stage of this
+    /// analyis this effect is not expected and should be abstracted away. Emit
+    /// a warning.
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+                     "SGMapPropagation stage.");
+    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+  } else
+    tensorDescLayout = valueLayout;
----------------
adam-smnk wrote:

nit: braces

https://github.com/llvm/llvm-project/pull/130240


More information about the Mlir-commits mailing list