[Mlir-commits] [mlir] [mlir][xegpu] SIMT distribution patterns for XeGPU CreateNdTdesc, LoadNd, StoreNd and Dpas Ops. (PR #135271)

Charitha Saumya llvmlistbot at llvm.org
Wed Apr 30 10:09:16 PDT 2025


================
@@ -646,14 +1427,48 @@ struct XeGPUSubgroupDistributePass final
 };
 } // namespace
 
-void XeGPUSubgroupDistributePass::runOnOperation() {
-  Operation *op = getOperation();
-  RunSGMapPropagation solver(op);
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+}
 
-  // Print the analysis result and exit.
+void XeGPUSubgroupDistributePass::runOnOperation() {
+  auto &analyis = getAnalysis<RunLayoutInfoPropagation>();
+  // Print the analysis result and exit. (for testing purposes)
   if (printOnly) {
     auto &os = llvm::outs();
-    solver.printAnalysisResult(os);
+    analyis.printAnalysisResult(os);
     return;
   }
+  auto getPropagatedLayout = [&](Value val) {
+    return analyis.getLayoutInfo(val);
+  };
+
+  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
+  // propagation analysis result.
+  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
+  if (failed(layoutAssignment.run())) {
+    signalPassFailure();
+    return;
+  }
+
+  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
+  // operation.
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+  // Finally, do the SIMD to SIMT distribution.
+  RewritePatternSet patterns(&getContext());
+  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // TODO: These are not used at this point.
+  auto distributionFn = [](Value val) { return AffineMap(); };
+  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
+                      int64_t warpSz) { return Value(); };
+  vector::populatePropagateWarpVectorDistributionPatterns(
+      patterns, distributionFn, shuffleFn);
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
----------------
charithaintc wrote:

Added a check.

https://github.com/llvm/llvm-project/pull/135271


More information about the Mlir-commits mailing list