[Mlir-commits] [mlir] [mlir][xegpu] Enable support for ConvertLayoutOp (PR #146176)

Chao Chen llvmlistbot at llvm.org
Mon Jun 30 08:07:47 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/146176

From 2e0f4dbcb5c3635904e6200cbe763b683e4e3f21 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 24 Jun 2025 19:57:09 +0000
Subject: [PATCH 1/5] update convert layout definition

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 21 +++----
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  4 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 60 +++++++++++++------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  4 ++
 mlir/test/Dialect/XeGPU/invalid.mlir          | 14 +----
 mlir/test/Dialect/XeGPU/layout.mlir           |  8 +--
 6 files changed, 67 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index daab65ec893b8..97887cef684df 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -918,21 +918,22 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
 def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["source", "result"]>]> {
     let summary = "Convert the layout of the input operand";
     let description = [{
-      `convert_layout` adjusts the data distribution across subgroups and/or work-items by modifying
-      the `LayoutAttr`. Both `srcMap` and `resMap` must correspond to the same programming scope, such
-      as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once the IR is
-      lowered to WI level because that is the end result of all distributions.
+      `convert_layout` redistributes data across subgroups and/or work-items from the `input_layout` to
+      the `target_layout`. Both `input_layout` and `target_layout` must correspond to the same programming
+      scope, such as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once
+      the IR is lowered to WI level because that is the end result of all distributions.
     }];
-    let arguments = (ins XeGPU_Vector2DType: $source,
-                         XeGPU_LayoutAttr: $srcMap,
-                         XeGPU_LayoutAttr: $resMap
-                         );
-    let results = (outs XeGPU_Vector2DType: $result);
+    let arguments = (ins XeGPU_VectorType: $source,
+                         XeGPU_LayoutAttr: $input_layout,
+                         XeGPU_LayoutAttr: $target_layout);
+    let results = (outs XeGPU_VectorType: $result);
     let assemblyFormat = [{
-        $source attr-dict `:` type($source)
+        $source prop-dict attr-dict `:` type($source)
     }];
 
+    let hasFolder = 1;
     let hasVerifier = 1;
+    let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
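
For reference, the updated assembly with the new prop-dict syntax and renamed layout operands renders like this (taken from the updated layout.mlir test further below):

    %2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
                                   target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
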
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84314875c2ae5..af40b3754bd8a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
 def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
 def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_ValueType: AnyTypeOf<[XeGPU_VectorType, XeGPU_ScalarType]>;
 
 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..10ce019d5a812 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -609,32 +609,58 @@ LogicalResult DpasOp::verify() {
 // XeGPU_ConvertLayoutOp
 //===----------------------------------------------------------------------===//
 LogicalResult ConvertLayoutOp::verify() {
-  auto srcMap = getSrcMapAttr();
-  auto resMap = getResMapAttr();
-  if (!srcMap)
-    return emitOpError("expected srcMap.");
-  if (!resMap)
-    return emitOpError("expected resMap.");
-
-  if (srcMap == resMap)
-    return emitOpError("expected different srcMap and resMap.");
+  auto srcLayout = getInputLayoutAttr();
+  auto resLayout = getTargetLayoutAttr();
+  if (!srcLayout)
+    return emitOpError("expected input layout.");
+  if (!resLayout)
+    return emitOpError("expected target layout.");
 
   // both srcMap and resMap should be WgLayout or SgLayout at the same time.
-  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
-      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
-    return emitOpError(
-        "expected srcMap and resMap be WgLayout or SgLayout at the same time.");
+  if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
+      (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+    return emitOpError("expected input layout and target layout to be WgLayout "
+                       "or SgLayout at the same time.");
 
   auto shape = getSource().getType().getShape();
-  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
-    return emitOpError("invalid srcMap, data cannot be evenly distributed.");
+  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+    return emitOpError(
+        "invalid input layout, data cannot be evenly distributed.");
 
-  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
-    return emitOpError("invalid resMap, data cannot be evenly distributed.");
+  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+    return emitOpError(
+        "invalid target layout, data cannot be evenly distributed.");
 
   return mlir::success();
 }
 
+OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
+  llvm::dbgs() << "\nSource from adaptor: " << adaptor.getSource() << "\n";
+  auto srcLayout = getInputLayoutAttr();
+  auto resLayout = getTargetLayoutAttr();
+  if (srcLayout == resLayout)
+    return adaptor.getSource();
+  return {};
+}
+
+struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
+  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputLayout = op.getInputLayoutAttr();
+    auto targetLayout = op.getTargetLayoutAttr();
+    if (inputLayout != targetLayout)
+      return failure();
+    rewriter.replaceOp(op, op.getSource());
+    return success();
+  }
+};
+
+void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<FoldConvertLayoutOp>(context);
+}
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b85a66a8bd36..aa1755e25996a 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -124,6 +124,10 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
     Operation *defOp = result.getDefiningOp();
     assert(defOp && "result must have a defining op");
 
+    // For ConvertLayoutOp, the layout of the result is its target layout attribute
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
+      return convertOp.getTargetLayoutAttr();
+
     // for LoadNdOp, the layout is stored in the tensor descriptor
     if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
       return getLayoutAttr(loadNd.getTensorDesc());
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a2778cd94d963..65e1d22449bdd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -511,19 +511,11 @@ func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vecto
   return
 }
 
-// -----
-func.func @convert_layout_same_map(%a: vector<32x64xf16>) {
-  // expected-error@+1 {{expected different srcMap and resMap}}
-  %2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-                                resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
-  gpu.return
-}
-
 // -----
 func.func @convert_layout_unmatch(%a: vector<32x64xf16>) {
-  // expected-error@+1 {{expected srcMap and resMap be WgLayout or SgLayout at the same time}}
-  %2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
-                                resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
+  // expected-error@+1 {{expected input layout and target layout to be WgLayout or SgLayout at the same time}}
+  %2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
+                                target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
   gpu.return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/layout.mlir b/mlir/test/Dialect/XeGPU/layout.mlir
index 7f3ebec225cdf..ef51dfbbfd574 100644
--- a/mlir/test/Dialect/XeGPU/layout.mlir
+++ b/mlir/test/Dialect/XeGPU/layout.mlir
@@ -35,14 +35,14 @@ gpu.func @create_nd_tdesc_wg_1(%src: memref<24x32xf32>) {
 }
 
 gpu.func @convert_layout(%a: vector<32x64xf16>) {
-  %2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-                                resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
+  %2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+                                target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
   gpu.return
 }
 
 gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
-  %2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
-                                resMap = #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 32], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
+  %2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
+                                target_layout = #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 32], lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
   gpu.return
 }
 

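A minimal before/after sketch of what the new folder and canonicalization pattern enable; #l here is a placeholder for any single layout attribute (hypothetical IR, not taken from the tests):

    // Before: input and target layouts are identical, so the conversion is a no-op.
    %r = xegpu.convert_layout %a <{input_layout = #l, target_layout = #l}> : vector<32x64xf16>
    // After folding/canonicalization: the op is erased and all uses of %r
    // are replaced with %a.
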
From 9e89e7279a56816b54f5eb5ce1fc9ed3fcde0576 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 24 Jun 2025 21:16:38 +0000
Subject: [PATCH 2/5] add convert layout blocking pattern

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp              | 12 ++++--------
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 11 +++++++++++
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 10ce019d5a812..54b1e360d11f1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -609,8 +609,8 @@ LogicalResult DpasOp::verify() {
 // XeGPU_ConvertLayoutOp
 //===----------------------------------------------------------------------===//
 LogicalResult ConvertLayoutOp::verify() {
-  auto srcLayout = getInputLayoutAttr();
-  auto resLayout = getTargetLayoutAttr();
+  auto srcLayout = getInputLayout();
+  auto resLayout = getTargetLayout();
   if (!srcLayout)
     return emitOpError("expected input layout.");
   if (!resLayout)
@@ -636,9 +636,7 @@ LogicalResult ConvertLayoutOp::verify() {
 
 OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
   llvm::dbgs() << "\nSource from adaptor: " << adaptor.getSource() << "\n";
-  auto srcLayout = getInputLayoutAttr();
-  auto resLayout = getTargetLayoutAttr();
-  if (srcLayout == resLayout)
+  if (getInputLayout() == getTargetLayout())
     return adaptor.getSource();
   return {};
 }
@@ -647,9 +645,7 @@ struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
   using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                 PatternRewriter &rewriter) const override {
-    auto inputLayout = op.getInputLayoutAttr();
-    auto targetLayout = op.getTargetLayoutAttr();
-    if (inputLayout != targetLayout)
+    if (op.getInputLayout() != op.getTargetLayout())
       return failure();
     rewriter.replaceOp(op, op.getSource());
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3950e8f70d1ca..bf6d0b3164e16 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -78,6 +78,17 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   }
 }
 
+struct ConvertLayoutOpPattern: public OpRewritePattern<xegpu::ConvertLayoutOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override {
+    xegpu::LayoutAttr input_layout = op.getInputLayoutAttr().dropInstData();
+    xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr().dropInstData();
+    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 //===------------------------------------------------------------------------===//
 // The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
 // to partition operations that process large shapes into multiple operations on

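To illustrate the blocking pattern: ConvertLayoutOpPattern strips inst_data from both layouts, and createOrFold erases the op entirely if the stripped layouts end up equal. A sketch with made-up inst_data values:

    // Before XeGPUBlocking (layouts still carry inst_data):
    %r = xegpu.convert_layout %a <{input_layout = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
                                   target_layout = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
    // After dropInstData both layouts are identical, so the fold from PATCH 1
    // replaces %r with %a.
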
From 149aeeaa3148f98d378177ccb64c8941a41d8dd4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 25 Jun 2025 15:00:36 +0000
Subject: [PATCH 3/5] add WgToSg pattern for convert layout

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  3 +-
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  1 +
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 34 +++++++++++++++++--
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 54b1e360d11f1..00fe251f48757 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -635,9 +635,8 @@ LogicalResult ConvertLayoutOp::verify() {
 }
 
 OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
-  llvm::dbgs() << "\nSource from adaptor: " << adaptor.getSource() << "\n";
   if (getInputLayout() == getTargetLayout())
-    return adaptor.getSource();
+    return getSource();
   return {};
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index bf6d0b3164e16..3472bceca40ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -346,6 +346,7 @@ void XeGPUBlockingPass::runOnOperation() {
   });
 
   RewritePatternSet patterns(ctx);
+  patterns.add<ConvertLayoutOpPattern>(ctx);
 
   vector::UnrollVectorOptions vectorOptions;
   vectorOptions.setNativeShapeFn(options.nativeShape);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e3563d10bc6f1..fa45169021581 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -390,6 +390,31 @@ struct WgToSgElementwiseOp : public ConversionPattern {
   }
 };
 
+struct WgToSgConvertLayoutOp
+    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
+  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::LayoutAttr input = op.getInputLayout();
+    xegpu::LayoutAttr target = op.getTargetLayout();
+    if (input.getSgLayout() == target.getSgLayout() &&
+        input.getSgData() == target.getSgData()) {
+      input = input.dropSgLayoutAndData();
+      target = target.dropSgLayoutAndData();
+      SmallVector<Value> newOps;
+      for (auto src : adaptor.getSource()) {
+        auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
+            op.getLoc(), src.getType(), src, input, target);
+        newOps.push_back(newOp);
+      }
+      rewriter.replaceOpWithMultiple(op, newOps);
+      return success();
+    }
+    return failure();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
@@ -473,8 +498,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
-      patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+               WgToSgConvertLayoutOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -581,6 +606,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
+      [=](xegpu::ConvertLayoutOp op) -> bool {
+        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
+      });
+
   target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
       [=](Operation *op) -> std::optional<bool> {
         // Only handle elementwise mappable ops

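A sketch of the WgToSg case handled by this pattern, where input and target share sg_layout and sg_data, so each subgroup fragment only needs a lane-level conversion (hypothetical values, modeled on the layout.mlir tests):

    // Workgroup level, before distribution:
    %r = xegpu.convert_layout %a <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>,
                                   target_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
    // After WgToSgConvertLayoutOp, per 16x16 subgroup fragment %a0:
    %r0 = xegpu.convert_layout %a0 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
                                     target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
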
From aee53c4cff7abc4665598c8ee9689456cc373889 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 27 Jun 2025 23:26:05 +0000
Subject: [PATCH 4/5] improve ConvertLayoutOpPattern

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  14 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 206 ++++++++++++++----
 2 files changed, 168 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7ef61de190b4c..6249d0484c215 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -313,13 +313,13 @@ LogicalResult TensorDescType::verify(
   if (rank != 1 && rank != 2)
     return emitError() << "expected 1D or 2D tensor";
 
-  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
-  if (blockAttr) {
-    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-    if (rank == 2 && memorySpaceAttr &&
-        memorySpaceAttr.getValue() == MemorySpace::SLM)
-      return emitError() << "SLM is not supported for 2D block tensor";
-  }
+  // auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+  // if (blockAttr) {
+  //   MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+  //   if (rank == 2 && memorySpaceAttr &&
+  //       memorySpaceAttr.getValue() == MemorySpace::SLM)
+  //     return emitError() << "SLM is not supported for 2D block tensor";
+  // }
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index fa45169021581..d542fb219a7c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -57,6 +57,39 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
+// Calculate offset for each subgroup
+static SmallVector<OpFoldResult>
+calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                       const SmallVector<OpFoldResult> &originalOffsets,
+                       const SmallVector<Value> &localOffset,
+                       const SmallVector<int64_t> &distUnitBaseAddr,
+                       const SmallVector<int64_t> &distUnitShape) {
+  assert(localOffset.size() == distUnitBaseAddr.size() &&
+         "localOffset and distUnitBaseAddr must have the same rank");
+
+  SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                          originalOffsets.end());
+  size_t rank = localOffset.size();
+  for (size_t i = 0; i < rank; ++i) {
+    size_t dimIdx = originalOffsets.size() - rank + i;
+    Value constOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+    Value offset =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+    Value modValue =
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
+    Value offsetMod =
+        rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
+    Value origOffset =
+        getValueOrCreateConstantIndexOp(rewriter, loc, originalOffsets[dimIdx]);
+    Value globalOffset =
+        rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
+    globalOffsets[dimIdx] = globalOffset;
+  }
+
+  return globalOffsets;
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -105,39 +138,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Calculate offset for each subgroup
-  SmallVector<OpFoldResult>
-  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
-                         const SmallVector<OpFoldResult> &originalOffsets,
-                         const SmallVector<Value> &localOffset,
-                         const SmallVector<int64_t> &distUnitBaseAddr,
-                         const SmallVector<int64_t> &distUnitShape) const {
-    assert(localOffset.size() == distUnitBaseAddr.size() &&
-           "localOffset and distUnitBaseAddr must have the same rank");
-
-    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
-                                            originalOffsets.end());
-    size_t rank = localOffset.size();
-    for (size_t i = 0; i < rank; ++i) {
-      size_t dimIdx = originalOffsets.size() - rank + i;
-      Value constOffset =
-          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
-      Value offset =
-          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
-      Value modValue =
-          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
-      Value offsetMod =
-          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
-      Value origOffset = getValueOrCreateConstantIndexOp(
-          rewriter, loc, originalOffsets[dimIdx]);
-      Value globalOffset =
-          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
-      globalOffsets[dimIdx] = globalOffset;
-    }
-
-    return globalOffsets;
-  }
-
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -390,6 +390,21 @@ struct WgToSgElementwiseOp : public ConversionPattern {
   }
 };
 
+// Allocate an SLM buffer based on the size of the given vector type.
+static TypedValue<MemRefType>
+allocateSLMBuffer(ConversionPatternRewriter &rewriter, Location loc,
+                  VectorType type) {
+  int64_t bits = type.getElementType().getIntOrFloatBitWidth();
+  int64_t slmSizeInBytes = type.getNumElements() * bits / 8;
+  auto slmTy = MemRefType::get(slmSizeInBytes, rewriter.getI8Type(), {}, 3);
+  auto slm = rewriter.create<memref::AllocOp>(loc, slmTy);
+  auto viewTy = MemRefType::get(type.getShape(), type.getElementType(), {}, 3);
+  auto view = rewriter.create<memref::ViewOp>(
+      loc, viewTy, slm, rewriter.create<arith::ConstantIndexOp>(loc, 0),
+      ValueRange());
+  return view;
+}
+
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
@@ -398,20 +413,121 @@ struct WgToSgConvertLayoutOp
                   ConversionPatternRewriter &rewriter) const override {
     xegpu::LayoutAttr input = op.getInputLayout();
     xegpu::LayoutAttr target = op.getTargetLayout();
-    if (input.getSgLayout() == target.getSgLayout() &&
-        input.getSgData() == target.getSgData()) {
-      input = input.dropSgLayoutAndData();
-      target = target.dropSgLayoutAndData();
-      SmallVector<Value> newOps;
-      for (auto src : adaptor.getSource()) {
-        auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
-            op.getLoc(), src.getType(), src, input, target);
-        newOps.push_back(newOp);
+
+    if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Input and target layouts must have subgroup layout");
+
+    // initialize values with the source values
+    SmallVector<Value> values(adaptor.getSource());
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    VectorType type = op.getResult().getType();
+    ArrayRef<int64_t> shape = type.getShape();
+
+    DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
+    DenseI32ArrayAttr inputSgData = input.getSgData();
+    DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
+    DenseI32ArrayAttr targetSgData = target.getSgData();
+
+    // We only need SLM support when the input and target layouts are different.
+    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData) {
+      values.clear();
+      rewriter.setInsertionPoint(op);
+      TypedValue<MemRefType> slmBuffer = allocateSLMBuffer(rewriter, loc, type);
+
+      auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(
+          loc, rewriter.getIndexType(), nullptr);
+
+      { // store to slm buffer
+        SmallVector<int64_t> sgLayout =
+            llvm::to_vector_of<int64_t>(input.getSgLayout().asArrayRef());
+        SmallVector<int64_t> sgShape = getSgShapeAndCount(shape, input).first;
+        auto delinearized = affine::delinearizeIndex(
+            rewriter, loc, linearSgId, getAsIndexOpFoldResult(ctx, sgLayout));
+        if (failed(delinearized))
+          return rewriter.notifyMatchFailure(op, "Failed to delinearize sgId");
+        SmallVector<Value> sgIds = *delinearized;
+
+        SmallVector<int64_t> distUnitShape(sgLayout.size());
+        SmallVector<Value> localOffset(sgLayout.size());
+        for (size_t i = 0; i < sgLayout.size(); i++) {
+          distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], shape[i]);
+          localOffset[i] = rewriter.createOrFold<index::MulOp>(
+              loc, sgIds[i],
+              rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]));
+        }
+
+        auto tdescTy = xegpu::TensorDescType::get(
+            sgShape, type.getElementType(), 1, false, xegpu::MemorySpace::SLM,
+            input.dropSgLayoutAndData());
+
+        SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult(
+            ctx, SmallVector<int64_t>(sgLayout.size(), 0));
+        for (auto [data, baseOffsets] :
+             llvm::zip_equal(adaptor.getSource(),
+                             StaticTileOffsetRange(shape, distUnitShape))) {
+          SmallVector<OpFoldResult> offsets = calculateGlobalOffsets(
+              rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
+          auto tdesc = rewriter.create<xegpu::CreateNdDescOp>(
+              loc, tdescTy, slmBuffer, offsets);
+          rewriter.create<xegpu::StoreNdOp>(loc, data, tdesc, nullptr, nullptr,
+                                            nullptr);
+        }
+      }
+
+      rewriter.create<gpu::BarrierOp>(loc);
+
+      { // load from SLM
+        SmallVector<int64_t> sgLayout =
+            llvm::to_vector_of<int64_t>(target.getSgLayout().asArrayRef());
+        SmallVector<int64_t> sgShape = getSgShapeAndCount(shape, target).first;
+        auto delinearized = affine::delinearizeIndex(
+            rewriter, loc, linearSgId, getAsIndexOpFoldResult(ctx, sgLayout));
+        if (failed(delinearized))
+          return rewriter.notifyMatchFailure(op, "Failed to delinearize sgId");
+        SmallVector<Value> sgIds = *delinearized;
+
+        SmallVector<int64_t> distUnitShape(sgLayout.size());
+        SmallVector<Value> localOffset(sgLayout.size());
+        for (size_t i = 0; i < sgLayout.size(); i++) {
+          distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], shape[i]);
+          localOffset[i] = rewriter.createOrFold<index::MulOp>(
+              loc, sgIds[i],
+              rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]));
+        }
+
+        auto tdescTy = xegpu::TensorDescType::get(
+            sgShape, type.getElementType(), 1, false, xegpu::MemorySpace::SLM,
+            target.dropSgLayoutAndData());
+        auto valueTy = VectorType::get(sgShape, type.getElementType());
+
+        SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult(
+            ctx, SmallVector<int64_t>(sgLayout.size(), 0));
+        for (auto baseOffsets : StaticTileOffsetRange(shape, distUnitShape)) {
+          SmallVector<OpFoldResult> offsets = calculateGlobalOffsets(
+              rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
+          auto tdesc = rewriter.create<xegpu::CreateNdDescOp>(
+              loc, tdescTy, slmBuffer, offsets);
+          auto newOp = rewriter.create<xegpu::LoadNdOp>(
+              loc, TypeRange({valueTy}), ValueRange({tdesc}));
+          values.push_back(newOp);
+        }
       }
-      rewriter.replaceOpWithMultiple(op, newOps);
-      return success();
     }
-    return failure();
+
+    input = input.dropSgLayoutAndData();
+    target = target.dropSgLayoutAndData();
+
+    SmallVector<Value> newOps;
+    for (auto src : values) {
+      auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
+          op.getLoc(), src.getType(), src, input, target);
+      newOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, newOps);
+    return success();
   }
 };
 

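To make the offset arithmetic in calculateGlobalOffsets concrete, per dimension d the pattern computes

    globalOffset[d] = originalOffsets[d] + ((localOffset[d] + distUnitBaseAddr[d]) mod distUnitShape[d])

A worked example with assumed values (not from the patch):

    shape         = [32, 64]
    sg_layout     = [2, 4], sg_data = [16, 16]
    distUnitShape = [min(2*16, 32), min(4*16, 64)] = [32, 64]
    subgroup (1, 2): localOffset = [1*16, 2*16] = [16, 32]
    with originalOffsets = [0, 0] and distUnitBaseAddr = [0, 0]:
    globalOffset  = [0 + (16 + 0) mod 32, 0 + (32 + 0) mod 64] = [16, 32]

So subgroup (1, 2) stores its 16x16 fragment at offset [16, 32] of the SLM view, and the load side repeats the same computation with the target layout to read back its redistributed fragment.
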
From c416cec159b701fbd405b049be1330f6ee24afc7 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 30 Jun 2025 15:07:33 +0000
Subject: [PATCH 5/5] code format

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3472bceca40ce..06e0c6105df58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -78,12 +78,15 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   }
 }
 
-struct ConvertLayoutOpPattern: public OpRewritePattern<xegpu::ConvertLayoutOp> {
+struct ConvertLayoutOpPattern
+    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
+                                PatternRewriter &rewriter) const override {
     xegpu::LayoutAttr input_layout = op.getInputLayoutAttr().dropInstData();
     xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr().dropInstData();
-    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
+    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
+        op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
     rewriter.replaceOp(op, newOp);
     return success();
   }


