[Mlir-commits] [mlir] [mlir][xegpu] Refine layout assignment in XeGPU SIMT distribution. (PR #142687)
Charitha Saumya
llvmlistbot at llvm.org
Mon Jun 9 16:29:56 PDT 2025
================
@@ -0,0 +1,955 @@
+//===- XeGPULayoutPropagate.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPULAYOUTPROPAGATE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-layout-propagate"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Layout
+//===----------------------------------------------------------------------===//
+
+/// Helper class to store the ND layout of lanes within a subgroup and data
+/// owned by each lane.
+struct Layout {
+ SmallVector<int64_t, 3> layout;
+ Layout() = default;
+ Layout(std::initializer_list<int64_t> list) : layout(list) {}
+ void print(llvm::raw_ostream &os) const;
+ size_t size() const { return layout.size(); }
+ int64_t operator[](size_t idx) const;
+};
+
+void Layout::print(llvm::raw_ostream &os) const {
+ os << llvm::interleaved_array(layout);
+}
+
+int64_t Layout::operator[](size_t idx) const {
+ assert(idx < layout.size() && "Index out of bounds.");
+ return layout[idx];
+}
+
+/// LaneLayout represents the logical layout of lanes within a subgroup when it
+/// accesses some value. LaneData represents the logical layout of data owned by
+/// each work item.
+using LaneLayout = Layout;
+using LaneData = Layout;
+
+//===----------------------------------------------------------------------===//
+// LayoutInfo
+//===----------------------------------------------------------------------===//
+
+/// Helper class for tracking the analysis state of an mlir value. For layout
+/// propagation, the analysis state is simply the lane_layout and lane_data of
+/// each value. Purpose of this analysis to propagate some unique layout for
+/// each value in the program starting from a set of anchor operations (like
+/// DPAS, StoreNd, etc.).
+///
+/// Given this, LayoutInfo satisifies the following properties:
+/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
+/// assigned`.
+/// 2) Two LayoutInfo values are equal if they are both assigned or
+/// both not assigned. The concrete value of assigned state does not matter.
+/// 3) The meet operator works as follows:
+/// - If current state is assigned, return the current state. (already
+/// a unique layout is assigned. don't change it)
+/// - Otherwise, return the other state.
+
+struct LayoutInfo {
+private:
+ LaneLayout laneLayout;
+ LaneData laneData;
+
+public:
+ LayoutInfo() = default;
+ LayoutInfo(const LaneLayout &layout, const LaneData &data)
+ : laneLayout(layout), laneData(data) {}
+
+ // Two lattice values are equal if they have `some` layout. The actual
+ // content of the layout does not matter.
+ bool operator==(const LayoutInfo &other) const {
+ return this->isAssigned() == other.isAssigned();
+ }
+
+ static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+ static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+ void print(raw_ostream &os) const;
+
+ bool isAssigned() const {
+ return laneLayout.size() > 0 && laneData.size() > 0;
+ }
+
+ LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+ const LaneLayout &getLayout() const { return laneLayout; }
+ const LaneData &getData() const { return laneData; }
+ ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
+ ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
+};
+
+void LayoutInfo::print(raw_ostream &os) const {
+ if (isAssigned()) {
+ os << "lane_layout: ";
+ laneLayout.print(os);
+ os << ", lane_data: ";
+ laneData.print(os);
+ } else {
+ os << "Not assigned.";
+ }
+}
+
+LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
+ if (!lhs.isAssigned())
+ return rhs;
+ return lhs;
+}
+
+/// Since this is a backward analysis, join method is not used.
+LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
+ llvm_unreachable("Join should not be triggered by layout propagation.");
+}
+
+/// Get the transposed layout according to the given permutation.
+LayoutInfo
+LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+ if (!isAssigned())
+ return {};
+ LaneLayout newLayout;
+ LaneData newData;
+ for (int64_t idx : permutation) {
+ newLayout.layout.push_back(laneLayout.layout[idx]);
+ newData.layout.push_back(laneData.layout[idx]);
+ }
+ return LayoutInfo(newLayout, newData);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoLattice
+//===----------------------------------------------------------------------===//
+
+/// Lattice holding the LayoutInfo for each value.
+struct LayoutInfoLattice : public Lattice<LayoutInfo> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
+ using Lattice::Lattice;
+};
+
+/// Helper Functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS).
+
+/// Helper Function to get the default layout for uniform values like constants.
+/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
+/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
+static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+ if (rank == 1)
+ return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+ LaneData({1}));
+ return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+ LaneData({1, 1}));
+}
+
+/// Helper to get the default layout for a vector type.
+static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
+ // Expecting a 1D or 2D vector.
+ assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+ "Expected 1D or 2D vector.");
+ // Expecting int or float element type.
+ assert(vectorTy.getElementType().isIntOrFloat() &&
+ "Expected int or float element type.");
+ // If the rank is 1, then return default layout for 1D vector.
+ if (vectorTy.getRank() == 1)
+ return getDefaultLayoutInfo(1);
+ // Packing factor is determined by the element type bitwidth.
+ int packingFactor = 1;
+ unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+ packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
+ return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+ LaneData({1, packingFactor}));
+}
+
+/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
+/// is set according to the following criteria:
+/// * For A operand, the data must be packed in minimum
+/// `packedSizeInBitsForDefault`
+/// * For B operand, the data must be packed in minimum
+/// `packedSizeInBitsForDpasB`
+static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
+ unsigned operandNum) {
+ Type elementTy = vectorTy.getElementType();
+ assert(elementTy.isIntOrFloat() &&
+ "Expected int or float type in DPAS operands");
+ LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
+ // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
+ // must have the VNNI format.
+ if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
+ xegpu::targetinfo::packedSizeInBitsForDpasB) {
+ LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
+ elementTy.getIntOrFloatBitWidth(),
+ 1});
+ return LayoutInfo(layout, data);
+ }
+ // Otherwise, return the default layout for the vector type.
+ return getDefaultLayoutInfo(vectorTy);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoPropagation
+//===----------------------------------------------------------------------===//
+
+/// Backward data flow analysis to propagate the lane_layout and lane_data of
+/// each value in the program. Currently, the layouts for operands DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
+/// this analysis is to propagate those known layouts to all their producers and
+/// (other) consumers.
+class LayoutInfoPropagation
+ : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
+private:
+ void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitStoreNdOp(xegpu::StoreNdOp store,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitLoadNdOp(xegpu::LoadNdOp load,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitLoadGatherOp(xegpu::LoadGatherOp load,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitTransposeOp(vector::TransposeOp transpose,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitVectorBitcastOp(vector::BitCastOp bitcast,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+public:
+ LayoutInfoPropagation(DataFlowSolver &solver,
+ SymbolTableCollection &symbolTable)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+ LogicalResult
+ visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) override;
+
+ void visitBranchOperand(OpOperand &operand) override {};
+
+ void visitCallOperand(OpOperand &operand) override {};
+
+ void visitExternalCall(CallOpInterface call,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) override {
+ };
+
+ void setToExitState(LayoutInfoLattice *lattice) override {
+ (void)lattice->meet(LayoutInfo());
+ }
+};
+} // namespace
+
+LogicalResult LayoutInfoPropagation::visitOperation(
+ Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ TypeSwitch<Operation *>(op)
+ .Case<xegpu::DpasOp>(
+ [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
+ .Case<xegpu::StoreNdOp>(
+ [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
+ .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
+ visitStoreScatterOp(storeScatterOp, operands, results);
+ })
+ .Case<xegpu::LoadNdOp>(
+ [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
+ .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
+ visitLoadGatherOp(loadGatherOp, operands, results);
+ })
+ .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
+ visitCreateDescOp(createDescOp, operands, results);
+ })
+ .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
+ visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+ })
+ .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+ visitPrefetchNdOp(prefetchNdOp, operands, results);
+ })
+ // No need to propagate the layout to operands in CreateNdDescOp because
+ // they are scalars (offsets, sizes, etc.).
+ .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
+ .Case<vector::TransposeOp>([&](auto transposeOp) {
+ visitTransposeOp(transposeOp, operands, results);
+ })
+ .Case<vector::BitCastOp>([&](auto bitcastOp) {
+ visitVectorBitcastOp(bitcastOp, operands, results);
+ })
+ .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
+ visitVectorMultiReductionOp(reductionOp, operands, results);
+ })
+ // All other ops.
+ .Default([&](Operation *op) {
+ for (const LayoutInfoLattice *r : results) {
+ for (LayoutInfoLattice *operand : operands) {
+ // Propagate the layout of the result to the operand.
+ if (r->getValue().isAssigned())
+ meet(operand, *r);
+ }
+ }
+ });
+ // Add a dependency from each result to program point after the operation.
+ for (const LayoutInfoLattice *r : results) {
+ addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
+ }
+ return success();
+}
+
+void LayoutInfoPropagation::visitPrefetchNdOp(
+ xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+ auto prefetchLayout = getDefaultLayoutInfo(
+ VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+ // Propagate the layout to the source tensor descriptor.
+ propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
+void LayoutInfoPropagation::visitVectorMultiReductionOp(
+ vector::MultiDimReductionOp reduction,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ // We only consider 2D -> 1D reductions at this point.
+ assert(resultLayout.getLayout().size() == 1 &&
----------------
charithaintc wrote:
Converted to a return.
The current assumption is that the layout rank is the same as the vector rank. Can you clarify where this assumption would be broken? If there are such cases, we need to modify this analysis in multiple places, not just here.
https://github.com/llvm/llvm-project/pull/142687
More information about the Mlir-commits
mailing list