[Mlir-commits] [mlir] e9fc393 - [MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 21 00:08:16 PST 2025
Author: Tuomas Kärnä
Date: 2025-11-21T10:08:12+02:00
New Revision: e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34
URL: https://github.com/llvm/llvm-project/commit/e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34
DIFF: https://github.com/llvm/llvm-project/commit/e9fc393a9e431d1a0aebc3fe448f3cf1668fbb34.diff
LOG: [MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)
The `set_op_layout_attr` and `set_desc_layout` transform ops wrap the
`xegpu.layout` attribute in an `xegpu.slice` attribute if the `slice_dims`
argument is set.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 16044838aa27d..29579acc727ed 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -42,10 +42,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
let description = [{
- Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
- attribute to the result tensor descriptor. The layout is defined by the
- `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
- to the transformed op.
+ Given an `xegpu.create_nd_desc` operation, this transform adds
+ `xegpu.layout` attribute to the result tensor descriptor. The layout is
+ defined by the `sg_layout`, and `sg_data` and optional `inst_data`
+ attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
+ wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to
+ the transformed op.
}];
let arguments = (ins
@@ -55,7 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
);
let results = (outs TransformHandleTypeInterface:$transformed);
@@ -63,7 +66,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
- "ArrayRef<OpFoldResult>":$mixedInstData
+ "ArrayRef<OpFoldResult>":$mixedInstData,
+ "ArrayRef<int64_t>":$sliceDims
)>,
];
@@ -72,6 +76,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ (`slice_dims` `=` $slice_dims^)?
attr-dict `:` functional-type(operands, results)
}];
@@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
target operand/result value is defined by the `index` argument. The layout
- is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+ is defined by the `sg_layout`, `sg_data` and optional `inst_data`
+ attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
+ wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
DefaultValuedAttr<UnitAttr, "false">:$result
);
@@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
+ "ArrayRef<int64_t>":$sliceDims,
CArg<"bool", "false">:$result
)>,
];
@@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ (`slice_dims` `=` $slice_dims^)?
attr-dict `:` qualified(type(operands))
}];
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index e301d4d9bd108..8995ab3082d24 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -167,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
- xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+ xegpu::CreateNdDescOp descOp,
+ xegpu::DistributeLayoutAttr layout) {
assert(descOp.getMixedOffsets().size() == 0 &&
"create desc op with offsets is not supported");
auto oldTensorDesc = descOp.getType();
@@ -212,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedSgLayout,
ArrayRef<OpFoldResult> mixedSgData,
- ArrayRef<OpFoldResult> mixedInstData) {
+ ArrayRef<OpFoldResult> mixedInstData,
+ ArrayRef<int64_t> sliceDims) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -225,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
/*inst_data=*/dynamicInstData,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
- /*static_inst_data=*/staticInstData);
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims);
}
DiagnosedSilenceableFailure
@@ -246,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
@@ -257,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attr in desc op's return type. Replaces old desc op.
- auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+ auto newdescOp = setDescLayout(rewriter, descOp, layout);
// Map result handles.
results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
@@ -278,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects(
void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
- ArrayRef<OpFoldResult> mixedInstData, bool result) {
+ ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
+ bool result) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -293,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims,
/*result=*/result);
}
@@ -326,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
// Set layout attribute for the op result or operand
if (resultTarget)
- xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
else
- xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 7169b5e28ab5e..5aa6453b7cb8a 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -62,6 +62,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
):
@@ -92,6 +93,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
+ slice_dims=slice_dims,
loc=loc,
ip=ip,
)
@@ -103,6 +105,7 @@ def set_desc_layout(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
) -> OpResult:
@@ -111,6 +114,7 @@ def set_desc_layout(
sg_layout,
sg_data,
inst_data=inst_data,
+ slice_dims=slice_dims,
loc=loc,
ip=ip,
).result
@@ -127,6 +131,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
loc=None,
@@ -156,6 +161,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
+ slice_dims=slice_dims,
index=index,
result=result,
loc=loc,
@@ -169,6 +175,7 @@ def set_op_layout_attr(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
loc=None,
@@ -179,6 +186,7 @@ def set_op_layout_attr(
sg_layout,
sg_data,
inst_data=inst_data,
+ slice_dims=slice_dims,
index=index,
result=result,
loc=loc,
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index ff0accdec7532..561034fb5880b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_desc_layout_slice
+func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_result_default_index
func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_op_layout_attr_result_slice
+func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
+ // CHECK: = arith.extf
+ // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
+ %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_operand_minimal
func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 4f89982ad1c44..2b11acb04ed5b 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -66,6 +66,25 @@ def setDescLayoutInstData():
# CHECK: inst_data = [8, 16]
+@run
+def setDescLayoutSlice():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(
+ sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutSlice
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: slice_dims = [0]
+
+
@run
def setOpLayoutAttrOperandMinimal():
sequence = transform.SequenceOp(
@@ -106,13 +125,41 @@ def setOpLayoutAttrResult():
result=True,
)
transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttr
+ # CHECK-LABEL: TEST: setOpLayoutAttrResult
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # NO-CHECK: index = 0
+ # CHECK: result
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+
+
+@run
+def setOpLayoutAttrResultSlice():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ slice_dims=[0],
+ result=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
# CHECK: transform.xegpu.set_op_layout_attr %
# NO-CHECK: index = 0
# CHECK: result
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
+ # CHECK: slice_dims = [0]
@run
More information about the Mlir-commits
mailing list