[Mlir-commits] [mlir] 48566b2 - [MLIR][XeGPU][TransformOps] set_op_layout_attr supports setting anchor layout (#172542)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 12 21:50:05 PST 2026


Author: Tuomas Kärnä
Date: 2026-02-13T07:49:59+02:00
New Revision: 48566b21a48573e0b2dcc67936ea8fdb7c40974a

URL: https://github.com/llvm/llvm-project/commit/48566b21a48573e0b2dcc67936ea8fdb7c40974a
DIFF: https://github.com/llvm/llvm-project/commit/48566b21a48573e0b2dcc67936ea8fdb7c40974a.diff

LOG: [MLIR][XeGPU][TransformOps] set_op_layout_attr supports setting anchor layout (#172542)

Changes `transform.xegpu.set_op_layout_attr` to support xegpu anchor
layouts. By default, if the `result` and `operand` bool arguments are unset,
this transform op sets the op's anchor layout if the op supports it
(otherwise it emits a silenceable failure). For `xegpu.dpas`, `index`
selects which anchor layout is set: 0 for A, 1 for B, and 2 for C/D.

In contrast to the earlier implementation, setting the operand layout
now requires setting the new `operand` argument.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
    mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
    mlir/python/mlir/dialects/transform/xegpu.py
    mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
    mlir/test/Dialect/XeGPU/transform-ops.mlir
    mlir/test/python/dialects/transform_xegpu_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 29579acc727ed..23dabe4eb380a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -109,12 +109,14 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
 
   let summary = "Set xegpu.layout attribute of an op.";
   let description = [{
-    Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
-    `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
-    target operand/result value is defined by the `index` argument. The layout
-    is defined by the `sg_layout`, `sg_data` and optional `inst_data`
-    attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
-    wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
+    Sets the `xegpu.layout` attribute of an op. By default it sets the anchor
+    layout for XeGPU ops that support it. If `result=true` or `operand=true`,
+    it sets the `layout_result_{index}` or `layout_operand_{index}` attribute,
+    respectively, applicable to any op. The target operand/result value is
+    defined by the `index` argument. The layout is defined by the `sg_layout`,
+    `sg_data` and optional `inst_data` attributes. If `slice_dims` is provided,
+    the `xegpu.layout` attribute is wrapped in an `xegpu.slice<..., dims=slice_dims>`
+    attribute.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
@@ -126,7 +128,8 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
-                   DefaultValuedAttr<UnitAttr, "false">:$result
+                   DefaultValuedAttr<UnitAttr, "false">:$result,
+                   DefaultValuedAttr<UnitAttr, "false">:$operand
                    );
 
   let results = (outs);
@@ -137,12 +140,13 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    "ArrayRef<OpFoldResult>":$mixedSgData,
                    "ArrayRef<OpFoldResult>":$mixedInstData,
                    "ArrayRef<int64_t>":$sliceDims,
-                   CArg<"bool", "false">:$result
+                   CArg<"bool", "false">:$result,
+                   CArg<"bool", "false">:$operand
                    )>,
   ];
 
   let assemblyFormat = [{
-    $target (`result` $result^)? (`index` `=` $index^)?
+    $target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
@@ -169,6 +173,8 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
       return getMixedValues(getStaticInstData(), getInstData(), b);
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 def SetGPULaunchThreadsOp

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 285dfac56be08..7bc67da8263dc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -291,7 +291,7 @@ void transform::SetOpLayoutAttrOp::build(
     OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
     ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
     ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
-    bool result) {
+    bool result, bool operand) {
   SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
   SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
   dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -307,7 +307,8 @@ void transform::SetOpLayoutAttrOp::build(
         /*static_sg_data=*/staticSgData,
         /*static_inst_data=*/staticInstData,
         /*slice_dims=*/sliceDims,
-        /*result=*/result);
+        /*result=*/result,
+        /*operand=*/operand);
 }
 
 DiagnosedSilenceableFailure
@@ -322,13 +323,14 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
   Operation *target = *targetOps.begin();
 
   bool resultTarget = getResult();
+  bool operandTarget = getOperand();
 
   int64_t index = getIndex();
   if (resultTarget && index >= target->getNumResults()) {
     return emitSilenceableFailure(getLoc())
            << "Index exceeds the number of op results";
   }
-  if (!resultTarget && index >= target->getNumOperands()) {
+  if (operandTarget && index >= target->getNumOperands()) {
     return emitSilenceableFailure(getLoc())
            << "Index exceeds the number of op operands";
   }
@@ -348,11 +350,38 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
         getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
   }
 
-  // Set layout attribute for the op result or operand
-  if (resultTarget)
+  // Set layout attribute
+  if (resultTarget) {
+    // op result
     xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
-  else
+  } else if (operandTarget) {
+    // op operand
     xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+  } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
+    // dpas op is a special case where layout needs to be set for A, B, and C
+    if (index == 0)
+      dpasOp.getProperties().layout_a = layout;
+    else if (index == 1)
+      dpasOp.getProperties().layout_b = layout;
+    else if (index == 2)
+      dpasOp.getProperties().layout_cd = layout;
+    else {
+      auto diag = emitSilenceableFailure(getLoc())
+                  << "Invalid index for setting dpas op layout: " << index;
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+  } else {
+    // op's anchor layout.
+    auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
+    if (!anchorOp) {
+      auto diag = emitSilenceableFailure(getLoc())
+                  << "Cannot set anchor layout to op: " << target->getName();
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    anchorOp.setAnchorLayout(layout);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -365,6 +394,13 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+LogicalResult transform::SetOpLayoutAttrOp::verify() {
+  if (getResult() && getOperand()) {
+    return emitOpError("Cannot set both result and operand simultaneously.");
+  }
+  return success();
+}
+
 void transform::SetGPULaunchThreadsOp::build(
     OpBuilder &builder, OperationState &ostate, Value target,
     ArrayRef<OpFoldResult> mixedThreads) {

diff  --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 5aa6453b7cb8a..a768ce5f4e720 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -134,6 +134,7 @@ def __init__(
         slice_dims: Optional[MixedInt] = None,
         index: Optional[Union[int, Attribute]] = None,
         result: Optional[Union[bool, Attribute]] = None,
+        operand: Optional[Union[bool, Attribute]] = None,
         loc=None,
         ip=None,
     ):
@@ -164,6 +165,7 @@ def __init__(
             slice_dims=slice_dims,
             index=index,
             result=result,
+            operand=operand,
             loc=loc,
             ip=ip,
         )
@@ -178,6 +180,7 @@ def set_op_layout_attr(
     slice_dims: Optional[MixedInt] = None,
     index: Optional[Union[int, Attribute]] = None,
     result: Optional[Union[bool, Attribute]] = None,
+    operand: Optional[Union[bool, Attribute]] = None,
     loc=None,
     ip=None,
 ) -> SetOpLayoutAttrOp:
@@ -189,6 +192,7 @@ def set_op_layout_attr(
         slice_dims=slice_dims,
         index=index,
         result=result,
+        operand=operand,
         loc=loc,
         ip=ip,
     )

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index dce4a41982550..2a147497a893b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -47,7 +47,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error@below {{Index exceeds the number of op operands}}
-    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
     transform.yield
   }
 }
@@ -67,6 +67,25 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_not_anchor_op
+func.func @set_op_layout_attr_not_anchor_op(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Cannot set anchor layout to op: arith.extf}}
     transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
     transform.yield
   }

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 13ed24ebf0a3a..9a278cbf7b498 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -140,23 +140,19 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @set_op_layout_attr_result_default_index
-func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_op_layout_attr_result_default
+func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
-  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
-  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
-  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
-  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
-  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
-  // CHECK: = xegpu.dpas
-  // CHECK-SAME: {layout_cd = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
-  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
   return
 }
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
     transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
     transform.yield
@@ -210,27 +206,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @set_op_layout_attr_result0
-func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
-  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
-  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
-  // CHECK: = arith.extf %1
-  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
-  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
-    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL: @set_op_layout_attr_result_slice
 func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
   // CHECK: = arith.extf
@@ -264,7 +239,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
-    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
     transform.yield
   }
 }
@@ -287,7 +262,102 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
-    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor
+func.func @set_op_layout_attr_anchor(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  // CHECK: = xegpu.load_nd %0[0, 0]
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
+func.func @set_op_layout_attr_anchor_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_b
+func.func @set_op_layout_attr_anchor_dpas_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [16, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_c
+func.func @set_op_layout_attr_anchor_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 2 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
     transform.yield
   }
 }

diff  --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 2b11acb04ed5b..e3e1313cf5f81 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -97,15 +97,17 @@ def setOpLayoutAttrOperandMinimal():
             sequence.bodyTarget,
             sg_layout=[6, 4],
             sg_data=[32, 16],
+            operand=True,
         )
         transform.YieldOp()
     # CHECK-LABEL: TEST: setOpLayoutAttr
     # CHECK: transform.xegpu.set_op_layout_attr %
-    # NO-CHECK: index = 0
-    # NO-CHECK: result
+    # CHECK: operand
+    # CHECK-NOT: index = 0
+    # CHECK-NOT: result
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
-    # NO-CHECK: inst_data
+    # CHECK-NOT: inst_data
 
 
 @run
@@ -127,8 +129,8 @@ def setOpLayoutAttrResult():
         transform.YieldOp()
     # CHECK-LABEL: TEST: setOpLayoutAttrResult
     # CHECK: transform.xegpu.set_op_layout_attr %
-    # NO-CHECK: index = 0
     # CHECK: result
+    # CHECK-NOT: index = 0
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
@@ -154,14 +156,40 @@ def setOpLayoutAttrResultSlice():
         transform.YieldOp()
     # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
     # CHECK: transform.xegpu.set_op_layout_attr %
-    # NO-CHECK: index = 0
     # CHECK: result
+    # CHECK-NOT: index = 0
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
     # CHECK: slice_dims = [0]
 
 
+@run
+def setOpLayoutAttrAnchor():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_op_layout_attr(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrAnchor
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # CHECK-NOT: result
+    # CHECK-NOT: operand
+    # CHECK-NOT: index = 0
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
 @run
 def setGPULaunchThreadsOp():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list