[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add set_op_layout_attr op (PR #166854)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 6 13:58:43 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tuomas Kärnä (tkarna)
<details>
<summary>Changes</summary>
Adds the `transform.xegpu.set_op_layout_attr` transform op, which attaches an `xegpu.layout` attribute to the target op.
Also adds the `getLayoutAttrFromOperands` utility function.
For reference, the rationale behind xegpu transform ops is outlined in [this RFC document](https://github.com/tkarna/llvm-project/blob/xegpu-transform-ops-doc/mlir/docs/XeGPUTransformOps.md).
---
Patch is 23.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166854.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+65)
- (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+106-19)
- (modified) mlir/python/mlir/dialects/transform/xegpu.py (+47)
- (modified) mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir (+58)
- (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+134)
- (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+49)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Set xegpu.layout attribute of an op.";
+ let description = [{
+ Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
+ `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
+ target operand/result value is defined by the `index` argument. The layout
+ is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface : $target,
+ DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedAttr<UnitAttr, "false">:$result
+ );
+
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "int64_t":$index,
+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
+ "ArrayRef<OpFoldResult>":$mixedSgData,
+ "ArrayRef<OpFoldResult>":$mixedInstData,
+ CArg<"bool", "false">:$result
+ )>,
+ ];
+
+ let assemblyFormat = [{
+ $target (`result` $result^)? (`index` `=` $index^)?
+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+ (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgData(), getSgData(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticInstData(), getInstData(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+ transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr =
+ createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();
- SmallVector<int32_t> sgLayout;
- DiagnosedSilenceableFailure status =
- convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;
- SmallVector<int32_t> sgData;
- status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
- if (!status.succeeded())
- return status;
-
- SmallVector<int32_t> instData;
- status =
- convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
- if (!status.succeeded())
- return status;
- auto maybeInstData = instData.empty()
- ? std::nullopt
- : std::optional<ArrayRef<int32_t>>(instData);
-
// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attr in desc op's return type. Replaces old desc op.
- auto layoutAttr =
- createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
// Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget) {
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+ } else {
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+ """Specialization for SetOpLayoutAttrOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ sg_layout: MixedValues,
+ sg_data: MixedValues,
+ *,
+ inst_data: MixedValues = None,
+ index: Union[int, Attribute] = None,
+ result: Union[bool, Attribute] = None,
+ loc=None,
+ ip=None,
+ ):
+ inst_data = [] if inst_data is None else inst_data
+ (
+ dynamic_sg_layout,
+ static_sg_layout,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_layout)
+ (
+ dynamic_sg_data,
+ static_sg_data,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_data)
+ (
+ dynamic_inst_data,
+ static_inst_data,
+ _,
+ ) = _dispatch_dynamic_index_list(inst_data)
+ super().__init__(
+ _get_op_result_or_value(target),
+ dynamic_sg_layout,
+ dynamic_sg_data,
+ dynamic_inst_data,
+ static_sg_layout=static_sg_layout,
+ static_sg_data=static_sg_data,
+ static_inst_data=static_inst_data,
+ index=index,
+ result=result,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Index exceeds the number of op results}}
+ transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Index exceeds the number of op operands}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+ %2 = arith.extf %1 : vector<256x32...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/166854
More information about the Mlir-commits
mailing list