[Mlir-commits] [mlir] d521863 - [MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (#178922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 13 09:02:25 PST 2026


Author: Nishant Patel
Date: 2026-02-13T09:02:19-08:00
New Revision: d521863f089a1eb7ea35a1bf73eb9fde30c2c49c

URL: https://github.com/llvm/llvm-project/commit/d521863f089a1eb7ea35a1bf73eb9fde30c2c49c
DIFF: https://github.com/llvm/llvm-project/commit/d521863f089a1eb7ea35a1bf73eb9fde30c2c49c.diff

LOG: [MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (#178922)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/invalid.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 7badfaf4a8216..f85dde0b98c8e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -27,6 +27,10 @@ class TensorDescType;
 class DistributeLayoutAttr;
 class LayoutAttr;
 class SliceAttr;
+
+/// Specifies the level of a layout hierarchy for comparison or propagation.
+enum class LayoutKind { Lane, InstData, Subgroup };
+
 } // namespace xegpu
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 898fb7e1d8e6d..377967dfdb1e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -311,7 +311,32 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     /*methodName=*/"isSliceOf",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
 
-    InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
+    InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
+                     at a specific level of the layout hierarchy. Unlike isEqualTo,
+                     this compares only the effective (non-sliced) fields at the
+                     requested level.}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isCompatibleWith",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
+                                  "xegpu::LayoutKind": $level),
+                    /*methodBody=*/[{
+                      if (!other)
+                        return false;
+                      switch (level) {
+                        case xegpu::LayoutKind::Subgroup:
+                          return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
+                                 $_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
+                        case xegpu::LayoutKind::InstData:
+                          return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
+                        case xegpu::LayoutKind::Lane:
+                          return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
+                                 $_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+                      }
+                      return false;
+                    }]>,
+    InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
+                     For LayoutAttr, this compares all fields.
+                     For SliceAttr, this requires the same parent and same sliced dims.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isEqualTo",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
@@ -546,9 +571,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
 
-    /// Check if this is identical to some other layout.
+    /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
-
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -734,11 +758,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is equal to another layout.
+    bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+
     /// Drop the slice dims to get the original layout.
     SliceAttr dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop);
-
-    /// Check if this is identical to some other layout.
-    bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   }];
 
   let description = [{
-    This operation loads a 2D block of data from shared local memory (SLM) as specified
-    by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation loads an nD block of data from shared local memory (SLM) as specified
+    by the provided nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
   let description = [{
-    This operation stores a 2D `data` fragment into the shared local memory region
-    specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation stores an nD `data` fragment into the shared local memory region
+    specified by an nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 0d5210f07f05a..35dc9541b5755 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -33,8 +33,6 @@ class TensorDescType;
 
 namespace xegpu {
 
-enum class LayoutKind { Lane, InstData, Subgroup };
-
 LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
                                LayoutKind layoutKind, bool printOnly = false);
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d99557e68f0ec..7ace00a746e21 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -755,6 +755,17 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
 
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  if (dyn_cast<xegpu::LayoutAttr>(other))
+    return false;
+
+  auto flattenedThis = flatten();
+  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+          (flattenedThis.getDims() == flattenedOther.getDims()));
+}
+
 xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
   if (sliceDimsToDrop.empty())
     return *this;
@@ -773,17 +784,6 @@ xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
   return sliceWithoutDims;
 }
 
-bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
-  if (dyn_cast<xegpu::LayoutAttr>(other))
-    return false;
-
-  auto flattenedThis = flatten();
-  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
-
-  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
-          (flattenedThis.getDims() == flattenedOther.getDims()));
-}
-
 // Helper function to adjust dimensions from sliced space to parent space
 // say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
 // shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       return success();
   }
 
-  if (mdescTy.getRank() != 2)
-    return emitError() << "mem_desc must be 2D.";
+  if (mdescTy.getRank() < 2)
+    return emitError() << "mem_desc must be 2D or greater.";
 
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 566fb45f5bcb4..66eb7fc97aa1a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -602,44 +602,110 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
 
-    auto input = op.getInputLayout();
-    auto target = op.getTargetLayout();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    auto inputLayout = op.getInputLayout();
+    auto targetLayout = op.getTargetLayout();
 
-    if (!input || !target || !input.isForWorkgroup() ||
-        !target.isForWorkgroup())
+    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+        !targetLayout.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
-    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr inputOrder = input.getOrder();
-    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr targetOrder = target.getOrder();
-
-    // TODO: currently we only support for optimal case, where input and
-    // output has the same sg_layout and sg_data, so SLM is not involved.
-    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
-        inputOrder != targetOrder)
+    SmallVector<int64_t> inputSgLayout =
+        inputLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> targetSgLayout =
+        targetLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+    // Fast path: if sg_layout and sg_data are identical, no SLM needed
+    if (inputLayout.isCompatibleWith(targetLayout,
+                                     xegpu::LayoutKind::Subgroup)) {
+      inputLayout = inputLayout.dropSgLayoutAndData();
+      targetLayout = targetLayout.dropSgLayoutAndData();
+
+      SmallVector<Value> newOps(adaptor.getSource());
+      if (inputLayout && targetLayout) {
+        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+          auto newOp = xegpu::ConvertLayoutOp::create(
+              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+          newOps[i] = newOp;
+        }
+      }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+      return success();
+    }
+
+    // SLM path: layouts differ, need cross-subgroup data redistribution
+    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+    SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+    // Calculate SLM size requirements
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+    // Allocate SLM
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // STORE PHASE: Each subgroup stores in SLM using input layout
+    auto storeCoords = inputLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(storeCoords))
       return failure();
 
-    input = input.dropSgLayoutAndData();
-    target = target.dropSgLayoutAndData();
+    // Store to SLM
+    for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+      SmallVector<OpFoldResult> storeMatrixOffsets;
+      for (Value coord : coords) {
+        storeMatrixOffsets.push_back(coord);
+      }
+      xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+                                   storeMatrixOffsets, nullptr /*layout*/);
+    }
 
-    SmallVector<Value> newOps(adaptor.getSource());
-    if (input && target) {
-      // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
-      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
-        auto newOp = xegpu::ConvertLayoutOp::create(
-            rewriter, op.getLoc(), src.getType(), src, input, target);
-        newOps[i] = newOp;
+    gpu::BarrierOp::create(rewriter, loc);
+
+    // LOAD PHASE: Each target subgroup loads from SLM using target layout
+    auto loadCoords = targetLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(loadCoords))
+      return failure();
+
+    VectorType loadType = VectorType::get(targetSgData, elemTy);
+
+    // Load vectors from SLM
+    SmallVector<Value> finalResults;
+    for (auto coords : *loadCoords) {
+      SmallVector<OpFoldResult> loadMatrixOffsets;
+      for (Value coord : coords) {
+        loadMatrixOffsets.push_back(coord);
       }
+      auto loadOp = xegpu::LoadMatrixOp::create(
+          rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+          targetLayout.dropSgLayoutAndData());
+
+      finalResults.push_back(loadOp.getResult());
     }
-    rewriter.replaceOpWithMultiple(op, {newOps});
+
+    rewriter.replaceOpWithMultiple(op, {finalResults});
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
 
 // -----
 func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
   return
 }
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
 
 // -----
 func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D.}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 69a2ca7c49c2d..e2e94c5f0300f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -828,6 +828,112 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: convert_layout_no_slm
+  gpu.func @convert_layout_no_slm(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+    %c32 = arith.constant 32 : index
+    %c4096 = arith.constant 4096 : index
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c256 overflow<nsw> : index
+    %1 = arith.muli %block_id_y, %c256 overflow<nsw> : index
+    %2 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
+    %3 = xegpu.load_nd %2[%0, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>> -> vector<256x256xf32>
+    %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>>
+    %5 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>>
+    %6 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %3) -> (vector<256x256xf32>) {
+      %7 = xegpu.load_nd %4[%0, %arg3] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>}> : !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+      %8 = xegpu.load_nd %5[%arg3, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>}> : !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<32x256xf16>
+      // CHECK: %[[CONVERT_A:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<32x32xf16>
+      // CHECK: %[[CONVERT_B:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}> : vector<32x32xf16>
+      %9 = xegpu.convert_layout %7 <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : vector<256x32xf16>
+      %10 = xegpu.convert_layout %8 <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>}> : vector<32x256xf16>
+      %11 = xegpu.dpas %9, %10, %arg4 {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32>
+      scf.yield %11 : vector<256x256xf32>
+    } {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+    xegpu.store_nd %6, %2[%0, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
+    gpu.return
+  }
+  
+  // CHECK-LABEL: convert_layout_slm
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+  gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]],  %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+    // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+    %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+    gpu.return
+  }
+
+  gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x32x16xindex>
+    // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<1x32x16xi1>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+    %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_nested_slice
   // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
   // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>


        


More information about the Mlir-commits mailing list