[Mlir-commits] [mlir] 48566b2 - [MLIR][XeGPU][TransformOps] set_op_layout_attr supports setting anchor layout (#172542)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 12 21:50:05 PST 2026
Author: Tuomas Kärnä
Date: 2026-02-13T07:49:59+02:00
New Revision: 48566b21a48573e0b2dcc67936ea8fdb7c40974a
URL: https://github.com/llvm/llvm-project/commit/48566b21a48573e0b2dcc67936ea8fdb7c40974a
DIFF: https://github.com/llvm/llvm-project/commit/48566b21a48573e0b2dcc67936ea8fdb7c40974a.diff
LOG: [MLIR][XeGPU][TransformOps] set_op_layout_attr supports setting anchor layout (#172542)
Changes `transform.xegpu.set_op_layout_attr` to support XeGPU anchor
layouts. By default, i.e. when neither the `result` nor the `operand` flag
is set, this transform op sets the op's anchor layout if the op supports
one, and otherwise emits a silenceable failure. For `xegpu.dpas`, `index`
0, 1 and 2 select the A, B and C/D layouts, respectively.
In contrast to the earlier implementation, setting an operand layout now
requires the new `operand` flag; `result` and `operand` cannot both be set.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 29579acc727ed..23dabe4eb380a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -109,12 +109,14 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
let summary = "Set xegpu.layout attribute of an op.";
let description = [{
- Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
- `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
- target operand/result value is defined by the `index` argument. The layout
- is defined by the `sg_layout`, `sg_data` and optional `inst_data`
- attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
- wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
+ Sets the `xegpu.layout` attribute of an op. By default it sets the anchor
+ layout for XeGPU ops that support it. If `result=true` or `operand=true`,
+ it sets the `layout_result_{index}` or `layout_operand_{index}` attribute,
+ respectively, applicable to any op. The target operand/result value is
+ defined by the `index` argument. The layout is defined by the `sg_layout`,
+ `sg_data` and optional `inst_data` attributes. If `slice_dims` is provided,
+ the `xegpu.layout` attribute is wrapped in an `xegpu.slice<..., dims=slice_dims>`
+ attribute.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -126,7 +128,8 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
- DefaultValuedAttr<UnitAttr, "false">:$result
+ DefaultValuedAttr<UnitAttr, "false">:$result,
+ DefaultValuedAttr<UnitAttr, "false">:$operand
);
let results = (outs);
@@ -137,12 +140,13 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int64_t>":$sliceDims,
- CArg<"bool", "false">:$result
+ CArg<"bool", "false">:$result,
+ CArg<"bool", "false">:$operand
)>,
];
let assemblyFormat = [{
- $target (`result` $result^)? (`index` `=` $index^)?
+ $target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
@@ -169,6 +173,8 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
return getMixedValues(getStaticInstData(), getInstData(), b);
}
}];
+
+ let hasVerifier = 1;
}
def SetGPULaunchThreadsOp
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 285dfac56be08..7bc67da8263dc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -291,7 +291,7 @@ void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
- bool result) {
+ bool result, bool operand) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -307,7 +307,8 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*slice_dims=*/sliceDims,
- /*result=*/result);
+ /*result=*/result,
+ /*operand=*/operand);
}
DiagnosedSilenceableFailure
@@ -322,13 +323,14 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
Operation *target = *targetOps.begin();
bool resultTarget = getResult();
+ bool operandTarget = getOperand();
int64_t index = getIndex();
if (resultTarget && index >= target->getNumResults()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op results";
}
- if (!resultTarget && index >= target->getNumOperands()) {
+ if (operandTarget && index >= target->getNumOperands()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op operands";
}
@@ -348,11 +350,38 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
}
- // Set layout attribute for the op result or operand
- if (resultTarget)
+ // Set layout attribute
+ if (resultTarget) {
+ // op result
xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
- else
+ } else if (operandTarget) {
+ // op operand
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+ } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
+ // dpas op is a special case where layout needs to be set for A, B, and C
+ if (index == 0)
+ dpasOp.getProperties().layout_a = layout;
+ else if (index == 1)
+ dpasOp.getProperties().layout_b = layout;
+ else if (index == 2)
+ dpasOp.getProperties().layout_cd = layout;
+ else {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Invalid index for setting dpas op layout: " << index;
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ } else {
+ // op's anchor layout.
+ auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
+ if (!anchorOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Cannot set anchor layout to op: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ anchorOp.setAnchorLayout(layout);
+ }
return DiagnosedSilenceableFailure::success();
}
@@ -365,6 +394,13 @@ void transform::SetOpLayoutAttrOp::getEffects(
modifiesPayload(effects);
}
+LogicalResult transform::SetOpLayoutAttrOp::verify() {
+ if (getResult() && getOperand()) {
+ return emitOpError("Cannot set both result and operand simultaneously.");
+ }
+ return success();
+}
+
void transform::SetGPULaunchThreadsOp::build(
OpBuilder &builder, OperationState &ostate, Value target,
ArrayRef<OpFoldResult> mixedThreads) {
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 5aa6453b7cb8a..a768ce5f4e720 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -134,6 +134,7 @@ def __init__(
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
+ operand: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
):
@@ -164,6 +165,7 @@ def __init__(
slice_dims=slice_dims,
index=index,
result=result,
+ operand=operand,
loc=loc,
ip=ip,
)
@@ -178,6 +180,7 @@ def set_op_layout_attr(
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
+ operand: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
) -> SetOpLayoutAttrOp:
@@ -189,6 +192,7 @@ def set_op_layout_attr(
slice_dims=slice_dims,
index=index,
result=result,
+ operand=operand,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index dce4a41982550..2a147497a893b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -47,7 +47,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Index exceeds the number of op operands}}
- transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}
@@ -67,6 +67,25 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_not_anchor_op
+func.func @set_op_layout_attr_not_anchor_op(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> // expected-note {{target op}}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Cannot set anchor layout to op: arith.extf}}
transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 13ed24ebf0a3a..9a278cbf7b498 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -140,23 +140,19 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @set_op_layout_attr_result_default_index
-func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_op_layout_attr_result_default
+func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
- %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
- %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
- %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
- // CHECK: = xegpu.dpas
- // CHECK-SAME: {layout_cd = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
transform.yield
@@ -210,27 +206,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @set_op_layout_attr_result0
-func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.extf %1
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: @set_op_layout_attr_result_slice
func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
// CHECK: = arith.extf
@@ -264,7 +239,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}
@@ -287,7 +262,102 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor
+func.func @set_op_layout_attr_anchor(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
+func.func @set_op_layout_attr_anchor_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_b
+func.func @set_op_layout_attr_anchor_dpas_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [16, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_c
+func.func @set_op_layout_attr_anchor_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 2 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 2b11acb04ed5b..e3e1313cf5f81 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -97,15 +97,17 @@ def setOpLayoutAttrOperandMinimal():
sequence.bodyTarget,
sg_layout=[6, 4],
sg_data=[32, 16],
+ operand=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttr
# CHECK: transform.xegpu.set_op_layout_attr %
- # NO-CHECK: index = 0
- # NO-CHECK: result
+ # CHECK: operand
+ # CHECK-NOT: index = 0
+ # CHECK-NOT: result
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
- # NO-CHECK: inst_data
+ # CHECK-NOT: inst_data
@run
@@ -127,8 +129,8 @@ def setOpLayoutAttrResult():
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResult
# CHECK: transform.xegpu.set_op_layout_attr %
- # NO-CHECK: index = 0
# CHECK: result
+ # CHECK-NOT: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
@@ -154,14 +156,40 @@ def setOpLayoutAttrResultSlice():
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
# CHECK: transform.xegpu.set_op_layout_attr %
- # NO-CHECK: index = 0
# CHECK: result
+ # CHECK-NOT: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
# CHECK: slice_dims = [0]
+@run
+def setOpLayoutAttrAnchor():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrAnchor
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # CHECK-NOT: result
+ # CHECK-NOT: operand
+ # CHECK-NOT: index = 0
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+
+
@run
def setGPULaunchThreadsOp():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list