[Mlir-commits] [mlir] e9fc393 - [MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 21 00:08:16 PST 2025


Author: Tuomas Kärnä
Date: 2025-11-21T10:08:12+02:00
New Revision: e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34

URL: https://github.com/llvm/llvm-project/commit/e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34
DIFF: https://github.com/llvm/llvm-project/commit/e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34.diff

LOG: [MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)

The `set_op_layout_attr` and `set_desc_layout` transform ops now wrap the
`xegpu.layout` attribute in an `xegpu.slice` attribute if the `slice_dims`
argument is set.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
    mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
    mlir/python/mlir/dialects/transform/xegpu.py
    mlir/test/Dialect/XeGPU/transform-ops.mlir
    mlir/test/python/dialects/transform_xegpu_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 16044838aa27d..29579acc727ed 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -42,10 +42,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
 
   let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
   let description = [{
-    Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
-    attribute to the result tensor descriptor. The layout is defined by the
-    `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
-    to the transformed op.
+    Given an `xegpu.create_nd_desc` operation, this transform adds
+    `xegpu.layout` attribute to the result tensor descriptor. The layout is
+    defined by the `sg_layout`, and `sg_data` and optional `inst_data`
+    attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
+    wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to
+    the transformed op.
   }];
 
   let arguments = (ins
@@ -55,7 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
                    Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
                    );
 
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -63,7 +66,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedSgData,
-                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   "ArrayRef<OpFoldResult>":$mixedInstData,
+                   "ArrayRef<int64_t>":$sliceDims
                    )>,
   ];
 
@@ -72,6 +76,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    (`slice_dims` `=` $slice_dims^)?
     attr-dict `:` functional-type(operands, results)
   }];
 
@@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
     Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
     `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
     target operand/result value is defined by the `index` argument. The layout
-    is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+    is defined by the `sg_layout`, `sg_data` and optional `inst_data`
+    attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
+    wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
@@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
                    DefaultValuedAttr<UnitAttr, "false">:$result
                    );
 
@@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedSgData,
                    "ArrayRef<OpFoldResult>":$mixedInstData,
+                   "ArrayRef<int64_t>":$sliceDims,
                    CArg<"bool", "false">:$result
                    )>,
   ];
@@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    (`slice_dims` `=` $slice_dims^)?
     attr-dict `:` qualified(type(operands))
   }];
 

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index e301d4d9bd108..8995ab3082d24 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -167,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
 static xegpu::CreateNdDescOp
 setDescLayout(transform::TransformRewriter &rewriter,
-              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+              xegpu::CreateNdDescOp descOp,
+              xegpu::DistributeLayoutAttr layout) {
   assert(descOp.getMixedOffsets().size() == 0 &&
          "create desc op with offsets is not supported");
   auto oldTensorDesc = descOp.getType();
@@ -212,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                        OperationState &result, Value target,
                                        ArrayRef<OpFoldResult> mixedSgLayout,
                                        ArrayRef<OpFoldResult> mixedSgData,
-                                       ArrayRef<OpFoldResult> mixedInstData) {
+                                       ArrayRef<OpFoldResult> mixedInstData,
+                                       ArrayRef<int64_t> sliceDims) {
   SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
   SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
   dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -225,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
         /*inst_data=*/dynamicInstData,
         /*static_sg_layout=*/staticSgLayout,
         /*static_sg_data=*/staticSgData,
-        /*static_inst_data=*/staticInstData);
+        /*static_inst_data=*/staticInstData,
+        /*slice_dims=*/sliceDims);
 }
 
 DiagnosedSilenceableFailure
@@ -246,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   if (!status.succeeded())
     return status;
 
+  xegpu::DistributeLayoutAttr layout = layoutAttr;
+  auto sliceDims = getSliceDims();
+  if (sliceDims.size() > 0) {
+    // Wrap layoutAttr in a slice attribute.
+    layout = xegpu::SliceAttr::get(
+        getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+  }
+
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
   if (!descOp) {
@@ -257,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+  auto newdescOp = setDescLayout(rewriter, descOp, layout);
 
   // Map result handles.
   results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
@@ -278,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects(
 void transform::SetOpLayoutAttrOp::build(
     OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
     ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
-    ArrayRef<OpFoldResult> mixedInstData, bool result) {
+    ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
+    bool result) {
   SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
   SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
   dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -293,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build(
         /*static_sg_layout=*/staticSgLayout,
         /*static_sg_data=*/staticSgData,
         /*static_inst_data=*/staticInstData,
+        /*slice_dims=*/sliceDims,
         /*result=*/result);
 }
 
@@ -326,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
   if (!status.succeeded())
     return status;
 
+  xegpu::DistributeLayoutAttr layout = layoutAttr;
+  auto sliceDims = getSliceDims();
+  if (sliceDims.size() > 0) {
+    // Wrap layoutAttr in a slice attribute.
+    layout = xegpu::SliceAttr::get(
+        getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+  }
+
   // Set layout attribute for the op result or operand
   if (resultTarget)
-    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+    xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
   else
-    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
   return DiagnosedSilenceableFailure::success();
 }
 

diff  --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 7169b5e28ab5e..5aa6453b7cb8a 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -62,6 +62,7 @@ def __init__(
         sg_data: MixedValues,
         *,
         inst_data: Optional[MixedValues] = None,
+        slice_dims: Optional[MixedInt] = None,
         loc=None,
         ip=None,
     ):
@@ -92,6 +93,7 @@ def __init__(
             static_sg_layout=static_sg_layout,
             static_sg_data=static_sg_data,
             static_inst_data=static_inst_data,
+            slice_dims=slice_dims,
             loc=loc,
             ip=ip,
         )
@@ -103,6 +105,7 @@ def set_desc_layout(
     sg_data: MixedValues,
     *,
     inst_data: Optional[MixedValues] = None,
+    slice_dims: Optional[MixedInt] = None,
     loc=None,
     ip=None,
 ) -> OpResult:
@@ -111,6 +114,7 @@ def set_desc_layout(
         sg_layout,
         sg_data,
         inst_data=inst_data,
+        slice_dims=slice_dims,
         loc=loc,
         ip=ip,
     ).result
@@ -127,6 +131,7 @@ def __init__(
         sg_data: MixedValues,
         *,
         inst_data: Optional[MixedValues] = None,
+        slice_dims: Optional[MixedInt] = None,
         index: Optional[Union[int, Attribute]] = None,
         result: Optional[Union[bool, Attribute]] = None,
         loc=None,
@@ -156,6 +161,7 @@ def __init__(
             static_sg_layout=static_sg_layout,
             static_sg_data=static_sg_data,
             static_inst_data=static_inst_data,
+            slice_dims=slice_dims,
             index=index,
             result=result,
             loc=loc,
@@ -169,6 +175,7 @@ def set_op_layout_attr(
     sg_data: MixedValues,
     *,
     inst_data: Optional[MixedValues] = None,
+    slice_dims: Optional[MixedInt] = None,
     index: Optional[Union[int, Attribute]] = None,
     result: Optional[Union[bool, Attribute]] = None,
     loc=None,
@@ -179,6 +186,7 @@ def set_op_layout_attr(
         sg_layout,
         sg_data,
         inst_data=inst_data,
+        slice_dims=slice_dims,
         index=index,
         result=result,
         loc=loc,

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index ff0accdec7532..561034fb5880b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_desc_layout_slice
+func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_op_layout_attr_result_default_index
 func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_op_layout_attr_result_slice
+func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
+  // CHECK: = arith.extf
+  // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
+  %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_op_layout_attr_operand_minimal
 func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>

diff  --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 4f89982ad1c44..2b11acb04ed5b 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -66,6 +66,25 @@ def setDescLayoutInstData():
     # CHECK: inst_data = [8, 16]
 
 
+@run
+def setDescLayoutSlice():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_desc_layout(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutSlice
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: slice_dims = [0]
+
+
 @run
 def setOpLayoutAttrOperandMinimal():
     sequence = transform.SequenceOp(
@@ -106,13 +125,41 @@ def setOpLayoutAttrResult():
             result=True,
         )
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setOpLayoutAttr
+    # CHECK-LABEL: TEST: setOpLayoutAttrResult
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # NO-CHECK: index = 0
+    # CHECK: result
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
+@run
+def setOpLayoutAttrResultSlice():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_op_layout_attr(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+            slice_dims=[0],
+            result=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
     # CHECK: transform.xegpu.set_op_layout_attr %
     # NO-CHECK: index = 0
     # CHECK: result
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+    # CHECK: slice_dims = [0]
 
 
 @run


        


More information about the Mlir-commits mailing list