[Mlir-commits] [mlir] [mlir][xegpu] SIMT distribution patterns for XeGPU CreateNdTdesc, LoadNd, StoreNd and Dpas Ops. (PR #135271)

Charitha Saumya llvmlistbot at llvm.org
Tue Apr 29 11:07:58 PDT 2025


================
@@ -190,112 +222,119 @@ static SGMap getDefaultSgMap(VectorType vectorTy) {
          "Expected int or float element type.");
   /// If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSgMap(1);
+    return getDefaultLayoutInfo(1);
   /// Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
-  auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   if (bitwidth < packedSizeInBitsForDefault)
     packingFactor = packedSizeInBitsForDefault / bitwidth;
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+  return LayoutInfo(LaneLayout({1, subgroupSize}),
+                    LaneData({1, packingFactor}));
 }
 
-/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
-/// set according to the following criteria:
+/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
+/// is set according to the following criteria:
 /// * For A operand, the data must be packed in minimum
 /// `packedSizeInBitsForDefault`
 /// * For B operand, the data must be packed in minimum
 /// `packedSizeInBitsForDpasB`
-static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
-  auto elementTy = vectorTy.getElementType();
+static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
+                                              unsigned operandNum) {
+  Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
-  WiLayout layout({1, subgroupSize});
+  LaneLayout layout({1, subgroupSize});
   /// For B operand, data must be packed in minimum `packedSizeInBitsForDpasB`
   /// and must have the VNNI format.
   if (operandNum == 1 &&
       elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
-    WiData data(
+    LaneData data(
         {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
-    return SGMap(layout, data);
+    return LayoutInfo(layout, data);
   }
   /// Otherwise, return the default layout for the vector type.
-  return getDefaultSgMap(vectorTy);
+  return getDefaultLayoutInfo(vectorTy);
 }
 
 ///===----------------------------------------------------------------------===///
-/// SGMapPropagation
+/// LayoutInfoPropagation
 ///===----------------------------------------------------------------------===///
 
-/// Backward data flow analysis to propagate the wi_layout and wi_data of each
-/// value in the program. Currently, the layouts for operands DPAS, StoreNd, and
-/// StoreScatter are fixed (known before propagation). Purpose of this analysis
-/// is to propagate those known layouts to all their producers and (other)
-/// consumers.
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+/// Backward data flow analysis to propagate the lane_layout and lane_data of
+/// each value in the program. Currently, the layouts for operands DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
+/// this analysis is to propagate those known layouts to all their producers and
+/// (other) consumers.
+class LayoutInfoPropagation
+    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
-  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
-                   ArrayRef<const SGMapLattice *> results);
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+                   ArrayRef<const LayoutInfoLattice *> results);
 
-  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
-                      ArrayRef<const SGMapLattice *> results);
+  void visitStoreNdOp(xegpu::StoreNdOp store,
+                      ArrayRef<LayoutInfoLattice *> operands,
+                      ArrayRef<const LayoutInfoLattice *> results);
 
   void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
+                           ArrayRef<LayoutInfoLattice *> operands,
+                           ArrayRef<const LayoutInfoLattice *> results);
 
-  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
-                     ArrayRef<const SGMapLattice *> results);
+  void visitLoadNdOp(xegpu::LoadNdOp load,
+                     ArrayRef<LayoutInfoLattice *> operands,
+                     ArrayRef<const LayoutInfoLattice *> results);
 
   void visitLoadGatherOp(xegpu::LoadGatherOp load,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
 
   void visitTransposeOp(vector::TransposeOp transpose,
-                        ArrayRef<SGMapLattice *> operands,
-                        ArrayRef<const SGMapLattice *> results);
+                        ArrayRef<LayoutInfoLattice *> operands,
+                        ArrayRef<const LayoutInfoLattice *> results);
 
   void visitVectorBitcastOp(vector::BitCastOp bitcast,
-                            ArrayRef<SGMapLattice *> operands,
-                            ArrayRef<const SGMapLattice *> results);
+                            ArrayRef<LayoutInfoLattice *> operands,
+                            ArrayRef<const LayoutInfoLattice *> results);
 
   void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
 
   void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
-                             ArrayRef<SGMapLattice *> operands,
-                             ArrayRef<const SGMapLattice *> results);
+                             ArrayRef<LayoutInfoLattice *> operands,
+                             ArrayRef<const LayoutInfoLattice *> results);
 
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
-                                   ArrayRef<SGMapLattice *> operands,
-                                   ArrayRef<const SGMapLattice *> results);
+                                   ArrayRef<LayoutInfoLattice *> operands,
+                                   ArrayRef<const LayoutInfoLattice *> results);
 
 public:
-  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+  LayoutInfoPropagation(DataFlowSolver &solver,
+                        SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
-  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
-                               ArrayRef<const SGMapLattice *> results) override;
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+                 ArrayRef<const LayoutInfoLattice *> results) override;
 
   void visitBranchOperand(OpOperand &operand) override {};
 
   void visitCallOperand(OpOperand &operand) override {};
 
   void visitExternalCall(CallOpInterface call,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results) override {};
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results) override {
+  };
 
-  void setToExitState(SGMapLattice *lattice) override {
-    (void)lattice->meet(SGMap());
+  void setToExitState(LayoutInfoLattice *lattice) override {
+    (void)lattice->meet(LayoutInfo());
   }
 };
 } // namespace
 
-LogicalResult
-SGMapPropagation::visitOperation(Operation *op,
-                                 ArrayRef<SGMapLattice *> operands,
-                                 ArrayRef<const SGMapLattice *> results) {
+LogicalResult LayoutInfoPropagation::visitOperation(
+    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
   TypeSwitch<Operation *>(op)
       .Case<xegpu::DpasOp>(
           [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
----------------
charithaintc wrote:

I will take a look at your code. I think we can try to incorporate those ideas as well in the future.

https://github.com/llvm/llvm-project/pull/135271


More information about the Mlir-commits mailing list