[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)

Chao Chen llvmlistbot at llvm.org
Thu Aug 21 07:36:33 PDT 2025


================
@@ -77,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
+/// Element-wise addition helper used for index computation.
+///
+/// `lhs` and `rhs` are aligned on their trailing (last) entries. Where both
+/// operands have an entry, the two are materialized as index values and
+/// combined with a folded `index::AddOp`. When the ranks differ, the extra
+/// leading entries of the longer operand are forwarded unchanged.
+///
+/// \returns a vector whose size is the larger of the two input ranks.
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  const size_t lhsRank = lhs.size();
+  const size_t rhsRank = rhs.size();
+  const size_t rank = lhsRank > rhsRank ? lhsRank : rhsRank;
+
+  // Walk from the trailing entry towards the front and reverse once at the
+  // end. Bounded indices are used instead of a pair of reverse iterators:
+  // stepping an iterator that has already reached rend() is undefined
+  // behaviour, and it made the loop run off the shorter operand whenever the
+  // ranks differed.
+  SmallVector<OpFoldResult> reversedResult;
+  reversedResult.reserve(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    if (i >= lhsRank) {
+      // lhs is exhausted: forward the remaining rhs entry as is.
+      reversedResult.push_back(rhs[rhsRank - 1 - i]);
+    } else if (i >= rhsRank) {
+      // rhs is exhausted: forward the remaining lhs entry as is.
+      reversedResult.push_back(lhs[lhsRank - 1 - i]);
+    } else {
+      OpFoldResult l = lhs[lhsRank - 1 - i];
+      OpFoldResult r = rhs[rhsRank - 1 - i];
+      Value lval = getValueOrCreateConstantIndexOp(rewriter, loc, l);
+      Value rval = getValueOrCreateConstantIndexOp(rewriter, loc, r);
+      Value sum = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
+      reversedResult.push_back(sum);
+    }
+  }
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// A callback function type used to create new load/store_matrix ops.
+// It receives `baseOffsets` as a read-only array and `descOffsets` as a
+// mutable reference to a list of per-descriptor offset vectors, and returns
+// nothing.
+using CreatorFuncType =
+    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+                            SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper for distributing logic shared by operations with offsets
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+              xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
----------------
chencha3 wrote:

Agree. I refactored the related code. But I still think we need a systematic approach to verify the validity of attributes and IR.

https://github.com/llvm/llvm-project/pull/154403


More information about the Mlir-commits mailing list