[Mlir-commits] [mlir] [mlir][xegpu] SIMT distribution patterns for XeGPU CreateNdTdesc, LoadNd, StoreNd and Dpas Ops. (PR #135271)
Charitha Saumya
llvmlistbot at llvm.org
Tue Apr 29 11:07:58 PDT 2025
================
@@ -190,112 +222,119 @@ static SGMap getDefaultSgMap(VectorType vectorTy) {
"Expected int or float element type.");
/// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSgMap(1);
+ return getDefaultLayoutInfo(1);
/// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
- auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (bitwidth < packedSizeInBitsForDefault)
packingFactor = packedSizeInBitsForDefault / bitwidth;
- return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+ return LayoutInfo(LaneLayout({1, subgroupSize}),
+ LaneData({1, packingFactor}));
}
-/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
-/// set according to the following criteria:
+/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
+/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
-static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
- auto elementTy = vectorTy.getElementType();
+static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
+ unsigned operandNum) {
+ Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
- WiLayout layout({1, subgroupSize});
+ LaneLayout layout({1, subgroupSize});
/// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
/// must have the VNNI format.
if (operandNum == 1 &&
elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
- WiData data(
+ LaneData data(
{packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
- return SGMap(layout, data);
+ return LayoutInfo(layout, data);
}
/// Otherwise, return the default layout for the vector type.
- return getDefaultSgMap(vectorTy);
+ return getDefaultLayoutInfo(vectorTy);
}
///===----------------------------------------------------------------------===///
-/// SGMapPropagation
+/// LayoutInfoPropagation
///===----------------------------------------------------------------------===///
-/// Backward data flow analysis to propagate the wi_layout and wi_data of each
-/// value in the program. Currently, the layouts for operands DPAS, StoreNd, and
-/// StoreScatter are fixed (known before propagation). Purpose of this analysis
-/// is to propagate those known layouts to all their producers and (other)
-/// consumers.
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+/// Backward data flow analysis to propagate the lane_layout and lane_data of
+/// each value in the program. Currently, the layouts for operands DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
+/// this analysis is to propagate those known layouts to all their producers and
+/// (other) consumers.
+class LayoutInfoPropagation
+ : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
- void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
- void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ void visitStoreNdOp(xegpu::StoreNdOp store,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
- void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ void visitLoadNdOp(xegpu::LoadNdOp load,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitLoadGatherOp(xegpu::LoadGatherOp load,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitTransposeOp(vector::TransposeOp transpose,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitVectorBitcastOp(vector::BitCastOp bitcast,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitCreateDescOp(xegpu::CreateDescOp createDesc,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
public:
- SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+ LayoutInfoPropagation(DataFlowSolver &solver,
+ SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
- LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) override;
+ LogicalResult
+ visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) override;
void visitBranchOperand(OpOperand &operand) override {};
void visitCallOperand(OpOperand &operand) override {};
void visitExternalCall(CallOpInterface call,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) override {};
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) override {
+ };
- void setToExitState(SGMapLattice *lattice) override {
- (void)lattice->meet(SGMap());
+ void setToExitState(LayoutInfoLattice *lattice) override {
+ (void)lattice->meet(LayoutInfo());
}
};
} // namespace
-LogicalResult
-SGMapPropagation::visitOperation(Operation *op,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) {
+LogicalResult LayoutInfoPropagation::visitOperation(
+ Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
TypeSwitch<Operation *>(op)
.Case<xegpu::DpasOp>(
[&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
----------------
charithaintc wrote:
I will take a look at your code. I think we can try to incorporate those ideas in the future as well.
https://github.com/llvm/llvm-project/pull/135271
More information about the Mlir-commits mailing list