[Mlir-commits] [mlir] 7cb57c6 - [MLIR][XeGPU][TransformOps] Remove obsolete transform ops (#187561)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 25 01:24:15 PDT 2026
Author: Tuomas Kärnä
Date: 2026-03-25T10:24:11+02:00
New Revision: 7cb57c680828ff1a8cb020296c9995cc01bf1c3d
URL: https://github.com/llvm/llvm-project/commit/7cb57c680828ff1a8cb020296c9995cc01bf1c3d
DIFF: https://github.com/llvm/llvm-project/commit/7cb57c680828ff1a8cb020296c9995cc01bf1c3d.diff
LOG: [MLIR][XeGPU][TransformOps] Remove obsolete transform ops (#187561)
Cleaning up XeGPU transform ops. Now that XeGPU layout propagation
works, it is sufficient to set the layouts for anchor ops (e.g.
load/store/dpas ops) only.
Changes:
* Remove `xegpu.get_desc_op` and `xegpu.set_desc_layout`. Users should
not change the layout of descriptor op's return value anymore.
* Add `xegpu.get_load_op(value)` that finds either `xegpu.load_nd` or
`xegpu.load` op in the value's producer chain. This is a useful utility
as load ops often need to be annotated with a layout.
* The generic `xegpu.set_op_layout_attr(op, ...)` is now replaced by
`xegpu.set_anchor_layout(op, ...)` that only sets layout attribute of
anchor ops. Raises an error if the given op does not support anchor
layouts.
* `xegpu.insert_prefetch` takes a load op handle instead of a value.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index f7f45508b6a03..40b9136874e7c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -16,110 +16,39 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
-def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
+def GetLoadOp : Op<Transform_Dialect, "xegpu.get_load_op", [
DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface
]> {
- let summary = "Get a handle to the descriptor op of a value.";
+ let summary = "Get a handle to the load_nd op in producer chain of a value.";
let description = [{
- Traces the producers of the given value until an `xegpu.create_nd_tdesc`
- descriptor op is found. Returns a handle to it. Currently traces
+ Traces the producers of the given value until an `xegpu.load_nd` or
+ `xegpu.load` op is found. Returns a handle to it. Currently traces
producers by following only the first operand of producer ops.
}];
let arguments = (ins TransformValueHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface:$descHandle);
+ let results = (outs TransformHandleTypeInterface:$loadNdHandle);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
-def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+def SetAnchorLayoutOp : Op<Transform_Dialect, "xegpu.set_anchor_layout", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface
]> {
- let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
+ let summary = "Set anchor layout of an op.";
let description = [{
- Given an `xegpu.create_nd_desc` operation, this transform adds
- `xegpu.layout` attribute to the result tensor descriptor. The layout is
- defined by the `sg_layout`, and `sg_data` and optional `inst_data`
- attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
- wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to
- the transformed op.
- }];
-
- let arguments = (ins
- TransformHandleTypeInterface:$target,
- Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
- Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
- Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
- DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
- );
-
- let results = (outs TransformHandleTypeInterface:$transformed);
- let builders = [
- OpBuilder<(ins "Value":$target,
- "ArrayRef<OpFoldResult>":$mixedSgLayout,
- "ArrayRef<OpFoldResult>":$mixedSgData,
- "ArrayRef<OpFoldResult>":$mixedInstData,
- "ArrayRef<int32_t>":$order,
- "ArrayRef<int64_t>":$sliceDims
- )>,
- ];
-
- let assemblyFormat = [{
- $target
- `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
- `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
- (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
- (`order` `=` $order^)?
- (`slice_dims` `=` $slice_dims^)?
- attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::transform::TransformResults &transformResults,
- ::mlir::transform::TransformState &state);
-
- ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
- Builder b(getContext());
- return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
- }
- ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
- Builder b(getContext());
- return getMixedValues(getStaticSgData(), getSgData(), b);
- }
- ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
- Builder b(getContext());
- return getMixedValues(getStaticInstData(), getInstData(), b);
- }
- }];
-}
-
-def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
- AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface
-]> {
-
- let summary = "Set xegpu.layout attribute of an op.";
- let description = [{
- Sets the `xegpu.layout` attribute of an op. By default it sets the anchor
- layout for XeGPU ops that support it. If `result=true` or `operand=true`,
- it sets the `layout_result_{index}` or `layout_operand_{index}` attribute,
- respectively, applicable to any op. The target operand/result value is
- defined by the `index` argument. The layout is defined by the `sg_layout`,
- `sg_data` and optional `inst_data` attributes. If `slice_dims` is provided,
- the `xegpu.layout` attribute is wrapped in an `xegpu.slice<..., dims=slice_dims>`
- attribute.
+ Sets the `xegpu.layout` anchor layout for XeGPU ops that support it. The
+ target operand value can be set by the `index` argument (currently only
+ applicable to a DPAS op). The layout is defined by the `sg_layout`,
+ `sg_data` and optional `inst_data` and `order` attributes. If `slice_dims`
+ is provided, the `xegpu.layout` attribute is wrapped in an
+ `xegpu.slice<..., dims=slice_dims>` attribute. Emits a silenceable failure
+ if the target op does not support anchor layouts.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -131,9 +60,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
- DefaultValuedAttr<UnitAttr, "false">:$result,
- DefaultValuedAttr<UnitAttr, "false">:$operand
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
);
let results = (outs);
@@ -144,14 +71,12 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int32_t>":$order,
- "ArrayRef<int64_t>":$sliceDims,
- CArg<"bool", "false">:$result,
- CArg<"bool", "false">:$operand
+ "ArrayRef<int64_t>":$sliceDims
)>,
];
let assemblyFormat = [{
- $target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
+ $target (`index` `=` $index^)?
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
@@ -179,8 +104,6 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
return getMixedValues(getStaticInstData(), getInstData(), b);
}
}];
-
- let hasVerifier = 1;
}
def SetGPULaunchThreadsOp
@@ -227,17 +150,15 @@ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
TransformOpInterface
]> {
- let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+ let summary = "Adds xegpu prefetch ops to a load op.";
let description = [{
- Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
- transform finds the corresponding `xegpu.load_nd` op and inserts
- `xegpu.prefetch_nd` operations for the tile. The load op must reside within
- the `scf.for` loop. Number of prefetch steps is set by the `nb_prefetch`
- argument (default value is 1). Returns a handle to the created
- `xegpu.create_nd_desc` op.
+ Inserts `xegpu.prefetch_nd` operations for the given `xegpu.load_nd` op.
+ The load op must reside within the `scf.for` loop. Number of prefetch steps
+ is set by the `nb_prefetch` argument (default value is 1). Returns a handle
+ to the created `xegpu.create_nd_desc` op.
}];
- let arguments = (ins TransformValueHandleTypeInterface:$target,
+ let arguments = (ins TransformHandleTypeInterface:$target,
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
);
@@ -275,9 +196,9 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
let summary = "Convert xegpu.layout attribute for a value.";
let description = [{
Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
- of a value. The input and target layouts are defined by the `*sg_layout`,
- `*sg_data` and optional `*inst_data` attributes. Returns a handle to the
- emitted `xegpu.convert_layout` op.
+ of a value before its first use. The input and target layouts are defined
+ by the `*sg_layout`, `*sg_data` and optional `*inst_data` and `*order`
+ attributes. Returns a handle to the emitted `xegpu.convert_layout` op.
}];
let arguments = (ins TransformValueHandleTypeInterface:$target,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 39f9ae0bf1287..153ef5b500a1b 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -165,30 +165,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
-/// Replace xegpu.create_nd_desc op with a new one with the given layout.
-static xegpu::CreateNdDescOp
-setDescLayout(transform::TransformRewriter &rewriter,
- xegpu::CreateNdDescOp descOp,
- xegpu::DistributeLayoutAttr layout) {
- assert(descOp.getMixedOffsets().size() == 0 &&
- "create desc op with offsets is not supported");
- auto oldTensorDesc = descOp.getType();
- auto descType = xegpu::TensorDescType::get(
- oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
- /*array_length=*/oldTensorDesc.getArrayLength(),
- /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
- /*memory_space=*/oldTensorDesc.getMemorySpace(),
- /*layout=*/layout);
-
- rewriter.setInsertionPointAfter(descOp);
- auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
- descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
- descOp.getMixedStrides());
- return newDescOp;
-}
-
DiagnosedSilenceableFailure
-transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+transform::GetLoadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto targetValues = state.getPayloadValues(getTarget());
@@ -198,102 +176,33 @@ transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
<< llvm::range_size(targetValues) << ")";
}
- auto maybeDescOp =
- findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
- if (!maybeDescOp) {
- return emitSilenceableFailure(getLoc())
- << "Could not find a matching descriptor op when walking the "
- "producer chain of the first operand.";
- }
-
- results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform::SetDescLayoutOp::build(OpBuilder &builder,
- OperationState &result, Value target,
- ArrayRef<OpFoldResult> mixedSgLayout,
- ArrayRef<OpFoldResult> mixedSgData,
- ArrayRef<OpFoldResult> mixedInstData,
- ArrayRef<int32_t> order,
- ArrayRef<int64_t> sliceDims) {
- SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
- SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
- dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
- dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
- dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
- build(builder, result, target.getType(),
- /*target=*/target,
- /*sg_layout=*/dynamicSgLayout,
- /*sg_data=*/dynamicSgData,
- /*inst_data=*/dynamicInstData,
- /*static_sg_layout=*/staticSgLayout,
- /*static_sg_data=*/staticSgData,
- /*static_inst_data=*/staticInstData,
- /*order=*/order,
- /*slice_dims=*/sliceDims);
-}
-
-DiagnosedSilenceableFailure
-transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
- auto targetOps = state.getPayloadOps(getTarget());
- if (!llvm::hasSingleElement(targetOps)) {
- return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
- << llvm::range_size(targetOps) << ")";
- }
- Operation *target = *targetOps.begin();
-
- xegpu::LayoutAttr layoutAttr = nullptr;
- auto status = getLayoutAttrFromOperands(
- getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
- getMixedInstData(), getOrder(), layoutAttr);
- if (!status.succeeded())
- return status;
-
- xegpu::DistributeLayoutAttr layout = layoutAttr;
- auto sliceDims = getSliceDims();
- if (sliceDims.size() > 0) {
- // Wrap layoutAttr in a slice attribute.
- layout = xegpu::SliceAttr::get(
- getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
- }
-
- // For now only create_nd_desc op is supported.
- auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
- if (!descOp) {
- auto diag = emitSilenceableFailure(getLoc())
- << "Expected a xegpu.create_nd_desc op, but got: "
- << target->getName();
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+ Operation *loadOp = nullptr;
+ auto maybeLoadNdOp =
+ findProducerOfType<xegpu::LoadNdOp>(*targetValues.begin());
+ if (maybeLoadNdOp) {
+ loadOp = maybeLoadNdOp->getOperation();
+ } else {
+ auto maybeLoadOp =
+ findProducerOfType<xegpu::LoadGatherOp>(*targetValues.begin());
+ if (maybeLoadOp) {
+ loadOp = maybeLoadOp->getOperation();
+ } else {
+ return emitSilenceableFailure(getLoc())
+ << "Could not find a matching xegpu.load_nd or xegpu.load op when "
+ "walking the "
+ "producer chain of the first operand.";
+ }
}
- // Set layout attr in desc op's return type. Replaces old desc op.
- auto newdescOp = setDescLayout(rewriter, descOp, layout);
-
- // Map result handles.
- results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
-
+ results.set(llvm::cast<OpResult>(getResult()), {loadOp});
return DiagnosedSilenceableFailure::success();
}
-void transform::SetDescLayoutOp::getEffects(
- ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTargetMutable(), effects);
- onlyReadsHandle(getSgLayoutMutable(), effects);
- onlyReadsHandle(getSgDataMutable(), effects);
- onlyReadsHandle(getInstDataMutable(), effects);
- producesHandle(getOperation()->getOpResults(), effects);
- modifiesPayload(effects);
-}
-
-void transform::SetOpLayoutAttrOp::build(
+void transform::SetAnchorLayoutOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
- ArrayRef<int64_t> sliceDims, bool result, bool operand) {
+ ArrayRef<int64_t> sliceDims) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -309,35 +218,17 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*order=*/order,
- /*slice_dims=*/sliceDims,
- /*result=*/result,
- /*operand=*/operand);
+ /*slice_dims=*/sliceDims);
}
DiagnosedSilenceableFailure
-transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+transform::SetAnchorLayoutOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
- if (!llvm::hasSingleElement(targetOps)) {
- return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
- << llvm::range_size(targetOps) << ")";
- }
- Operation *target = *targetOps.begin();
-
- bool resultTarget = getResult();
- bool operandTarget = getOperand();
-
int64_t index = getIndex();
- if (resultTarget && index >= target->getNumResults()) {
- return emitSilenceableFailure(getLoc())
- << "Index exceeds the number of op results";
- }
- if (operandTarget && index >= target->getNumOperands()) {
- return emitSilenceableFailure(getLoc())
- << "Index exceeds the number of op operands";
- }
+ // Construct layout attribute.
xegpu::LayoutAttr layoutAttr = nullptr;
auto status = getLayoutAttrFromOperands(
getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
@@ -353,42 +244,39 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
}
- // Set layout attribute
- if (resultTarget) {
- // op result
- xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
- } else if (operandTarget) {
- // op operand
- xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
- } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
- // dpas op is a special case where layout needs to be set for A, B, and C
- if (index == 0)
- dpasOp.getProperties().layout_a = layout;
- else if (index == 1)
- dpasOp.getProperties().layout_b = layout;
- else if (index == 2)
- dpasOp.getProperties().layout_cd = layout;
- else {
- auto diag = emitSilenceableFailure(getLoc())
- << "Invalid index for setting dpas op layout: " << index;
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
- } else {
- // op's anchor layout.
- auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
- if (!anchorOp) {
- auto diag = emitSilenceableFailure(getLoc())
- << "Cannot set anchor layout to op: " << target->getName();
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+ // Apply the layout to all target ops.
+ for (Operation *target : targetOps) {
+ // Set layout attribute
+ if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
+ // dpas op is a special case where layout needs to be set for A, B, and C
+ if (index == 0)
+ dpasOp.getProperties().layout_a = layout;
+ else if (index == 1)
+ dpasOp.getProperties().layout_b = layout;
+ else if (index == 2)
+ dpasOp.getProperties().layout_cd = layout;
+ else {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Invalid index for setting dpas op layout: " << index;
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ } else {
+ // op's anchor layout.
+ auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
+ if (!anchorOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Cannot set anchor layout to op: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ anchorOp.setAnchorLayout(layout);
}
- anchorOp.setAnchorLayout(layout);
}
return DiagnosedSilenceableFailure::success();
}
-void transform::SetOpLayoutAttrOp::getEffects(
+void transform::SetAnchorLayoutOp::getEffects(
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
onlyReadsHandle(getSgLayoutMutable(), effects);
@@ -397,13 +285,6 @@ void transform::SetOpLayoutAttrOp::getEffects(
modifiesPayload(effects);
}
-LogicalResult transform::SetOpLayoutAttrOp::verify() {
- if (getResult() && getOperand()) {
- return emitOpError("Cannot set both result and operand simultaneously.");
- }
- return success();
-}
-
void transform::SetGPULaunchThreadsOp::build(
OpBuilder &builder, OperationState &ostate, Value target,
ArrayRef<OpFoldResult> mixedThreads) {
@@ -471,12 +352,12 @@ DiagnosedSilenceableFailure
transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- auto targetValues = state.getPayloadValues(getTarget());
- if (!llvm::hasSingleElement(targetValues))
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps))
return emitDefiniteFailure()
- << "requires exactly one target value handle (got "
- << llvm::range_size(targetValues) << ")";
- auto value = *targetValues.begin();
+ << "requires exactly one target op handle (got "
+ << llvm::range_size(targetOps) << ")";
+ auto target = *targetOps.begin();
int64_t nbPrefetch = getStaticNbPrefetch();
if (getDynamicNbPrefetch()) {
@@ -495,11 +376,13 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
return emitSilenceableFailure(getLoc())
<< "nb_prefetch must be a positive integer.";
- // Find load operation of the operand.
- auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
- if (!maybeLoadOp)
- return emitSilenceableFailure(getLoc()) << "Could not find load op.";
- auto loadOp = *maybeLoadOp;
+ // Cast target to load op.
+ auto maybeLoadOp = dyn_cast<xegpu::LoadNdOp>(target);
+ if (!maybeLoadOp) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected xegpu.load_nd op, got " << target->getName();
+ }
+ auto loadOp = maybeLoadOp;
if (loadOp.getMixedOffsets().size() == 0) {
auto diag = emitSilenceableFailure(getLoc())
<< "Load op must have offsets.";
@@ -517,7 +400,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
}
// Find descriptor op.
- auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ auto maybeDescOp =
+ findProducerOfType<xegpu::CreateNdDescOp>(loadOp.getResult());
if (!maybeDescOp)
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
auto descOp = *maybeDescOp;
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 782c9a3f242a0..6e27e5c8ecfa6 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -23,8 +23,8 @@
@_ods_cext.register_operation(_Dialect, replace=True)
-class GetDescOp(GetDescOp):
- """Specialization for GetDescOp class."""
+class GetLoadOp(GetLoadOp):
+ """Specialization for GetLoadOp class."""
def __init__(
self,
@@ -33,100 +33,27 @@ def __init__(
loc=None,
ip=None,
):
- desc_type = transform.AnyOpType.get()
+ load_nd_type = transform.AnyOpType.get()
super().__init__(
- desc_type,
+ load_nd_type,
target,
loc=loc,
ip=ip,
)
-def get_desc_op(
+def get_load_op(
target: Value,
*,
loc=None,
ip=None,
) -> OpResult:
- return GetDescOp(target, loc=loc, ip=ip).result
+ return GetLoadOp(target, loc=loc, ip=ip).result
@_ods_cext.register_operation(_Dialect, replace=True)
-class SetDescLayoutOp(SetDescLayoutOp):
- """Specialization for SetDescLayoutOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- sg_layout: MixedValues,
- sg_data: MixedValues,
- *,
- inst_data: Optional[MixedValues] = None,
- order: Optional[MixedInt] = None,
- slice_dims: Optional[MixedInt] = None,
- loc=None,
- ip=None,
- ):
- target_handle = _get_op_result_or_value(target)
- inst_data = [] if inst_data is None else inst_data
- (
- dynamic_sg_layout,
- static_sg_layout,
- _,
- ) = _dispatch_dynamic_index_list(sg_layout)
- (
- dynamic_sg_data,
- static_sg_data,
- _,
- ) = _dispatch_dynamic_index_list(sg_data)
- (
- dynamic_inst_data,
- static_inst_data,
- _,
- ) = _dispatch_dynamic_index_list(inst_data)
-
- super().__init__(
- target_handle.type,
- target_handle,
- dynamic_sg_layout,
- dynamic_sg_data,
- dynamic_inst_data,
- static_sg_layout=static_sg_layout,
- static_sg_data=static_sg_data,
- static_inst_data=static_inst_data,
- order=order,
- slice_dims=slice_dims,
- loc=loc,
- ip=ip,
- )
-
-
-def set_desc_layout(
- target: Union[Operation, Value],
- sg_layout: MixedValues,
- sg_data: MixedValues,
- *,
- inst_data: Optional[MixedValues] = None,
- order: Optional[MixedInt] = None,
- slice_dims: Optional[MixedInt] = None,
- loc=None,
- ip=None,
-) -> OpResult:
- return SetDescLayoutOp(
- target,
- sg_layout,
- sg_data,
- inst_data=inst_data,
- order=order,
- slice_dims=slice_dims,
- loc=loc,
- ip=ip,
- ).result
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
- """Specialization for SetOpLayoutAttrOp class."""
+class SetAnchorLayoutOp(SetAnchorLayoutOp):
+ """Specialization for SetAnchorLayoutOp class."""
def __init__(
self,
@@ -138,8 +65,6 @@ def __init__(
order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
- result: Optional[Union[bool, Attribute]] = None,
- operand: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
):
@@ -170,14 +95,12 @@ def __init__(
order=order,
slice_dims=slice_dims,
index=index,
- result=result,
- operand=operand,
loc=loc,
ip=ip,
)
-def set_op_layout_attr(
+def set_anchor_layout(
target: Union[Operation, Value],
sg_layout: MixedValues,
sg_data: MixedValues,
@@ -186,12 +109,10 @@ def set_op_layout_attr(
order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
- result: Optional[Union[bool, Attribute]] = None,
- operand: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
-) -> SetOpLayoutAttrOp:
- return SetOpLayoutAttrOp(
+) -> SetAnchorLayoutOp:
+ return SetAnchorLayoutOp(
target,
sg_layout,
sg_data,
@@ -199,8 +120,6 @@ def set_op_layout_attr(
order=order,
slice_dims=slice_dims,
index=index,
- result=result,
- operand=operand,
loc=loc,
ip=ip,
)
@@ -249,7 +168,7 @@ class InsertPrefetchOp(InsertPrefetchOp):
def __init__(
self,
- target: Value,
+ target: Union[Operation, Value],
*,
nb_prefetch: Optional[MixedInt] = 1,
loc=None,
@@ -275,7 +194,7 @@ def __init__(
def insert_prefetch(
- target: Value,
+ target: Union[Operation, Value],
*,
nb_prefetch: Optional[MixedInt] = 1,
loc=None,
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 2a147497a893b..ba259f311d76e 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -1,81 +1,7 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
-func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
- %c32 = arith.constant 32 : index // expected-note {{target op}}
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_bad_result_index
-func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error@below {{Index exceeds the number of op results}}
- transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
-func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error@below {{Index exceeds the number of op operands}}
- transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_multiple
-func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
- %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
- transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_not_anchor_op
-func.func @set_op_layout_attr_not_anchor_op(%arg0: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_not_anchor_op
+func.func @set_anchor_layout_not_anchor_op(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> // expected-note {{target op}}
@@ -86,7 +12,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Cannot set anchor layout to op: arith.extf}}
- transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.xegpu.set_anchor_layout %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}
@@ -169,8 +95,9 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
// expected-error@below {{Load op is not contained in a scf.for loop.}}
- %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ %3 = transform.xegpu.insert_prefetch %2 nb_prefetch = 1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 5bb1ab708e301..acba80d870253 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -1,43 +1,35 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
-// CHECK-LABEL: @get_desc_op_a
-func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
- %c32 = arith.constant 32 : index
- %c4096 = arith.constant 4096 : index
+// CHECK-LABEL: @get_load_op
+func.func @get_load_op(%arg0: memref<4096x4096xf16>) {
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
- %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
- // expected-remark @below {{found desc op}}
- %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
- %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
- %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
- %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
- scf.yield %7 : vector<256x256xf16>
- }
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: xegpu.load_nd
+ // expected-remark @below {{found load_nd op}}
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
- %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
- transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found load_nd op" : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @get_desc_op_c
-func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @get_load_op_c
+func.func @get_load_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
- // expected-remark @below {{found desc op}}
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ // expected-remark @below {{found load_nd op}}
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
@@ -54,282 +46,150 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
- %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
- transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_desc_layout
-func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
- // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
- // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_desc_layout %{{.*}}
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found load_nd op" : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_desc_layout_minimal
-func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
- // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+// CHECK-LABEL: @get_load_op_1d
+func.func @get_load_op_1d(%arg0: memref<4096xf32>) {
+ %cst = arith.constant dense<true> : vector<256xi1>
+ %0 = vector.step : vector<256xindex>
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xf32> -> index
+ %1 = arith.index_cast %intptr : index to i64
+ // CHECK: xegpu.load %1[%0]
+ // expected-remark @below {{found load op}}
+ %2 = xegpu.load %1[%0], %cst : i64, vector<256xindex>, vector<256xi1> -> vector<256xf32>
+ %3 = arith.extf %2 : vector<256xf32> to vector<256xf64>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_desc_layout %{{.*}}
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found load op" : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_desc_layout_param
-func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
- // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+// CHECK-LABEL: @set_anchor_layout
+func.func @set_anchor_layout(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_desc_layout %{{.*}}
- %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_desc_layout_slice
-func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
- // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_desc_layout %{{.*}}
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_desc_layout_order
-func.func @set_desc_layout_order(%arg0: memref<4096x4096xf16>) {
- // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
- // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_desc_layout %{{.*}}
- %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_result_default
-func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_multiple
+func.func @set_anchor_layout_multiple(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.extf %1
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ // CHECK: xegpu.prefetch_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK: xegpu.prefetch_nd %0[16, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ xegpu.prefetch_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16>
+ xegpu.prefetch_nd %0[16, 0] : !xegpu.tensor_desc<256x32xf16>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.prefetch_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_result_sg_param
-func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_param
+func.func @set_anchor_layout_param(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.extf %1
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
- transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
-func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_param2
+func.func @set_anchor_layout_param2(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.extf %1
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
%layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
- transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_result_slice
-func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
- // CHECK: = arith.extf
- // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
- %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_result_order
-func.func @set_op_layout_attr_result_order(%arg0: vector<256xf16>) {
- // CHECK: = arith.extf
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [0, 1]>}
- %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+// CHECK-LABEL: @set_anchor_layout_slice
+func.func @set_anchor_layout_slice(%arg0: memref<4096xf32>) {
+ // CHECK: = xegpu.load %1[%0]
+ // CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>}>
+ %cst = arith.constant dense<true> : vector<256xi1>
+ %0 = vector.step : vector<256xindex>
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xf32> -> index
+ %1 = arith.index_cast %intptr : index to i64
+ %2 = xegpu.load %1[%0], %cst : i64, vector<256xindex>, vector<256xi1> -> vector<256xf32>
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [0, 1] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_operand_minimal
-func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.extf %1
- // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
- %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_operand1
-func.func @set_op_layout_attr_operand1(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- // CHECK: = arith.addf %1, %3
- // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
- %6 = arith.addf %1, %3 : vector<256x32xf16>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @set_op_layout_attr_anchor
-func.func @set_op_layout_attr_anchor(%arg0: memref<4096x4096xf16>) {
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- // CHECK: = xegpu.load_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
- %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.load"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_anchor_order
-func.func @set_op_layout_attr_anchor_order(%arg0: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_order
+func.func @set_anchor_layout_order(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: = xegpu.load_nd %0[0, 0]
// CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
@@ -340,8 +200,8 @@ func.func @set_op_layout_attr_anchor_order(%arg0: memref<4096x4096xf16>) {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
transform.yield
}
}
@@ -349,8 +209,8 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
-func.func @set_op_layout_attr_anchor_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_dpas_a
+func.func @set_anchor_layout_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
@@ -366,16 +226,16 @@ func.func @set_op_layout_attr_anchor_dpas_a(%arg0: memref<4096x4096xf16>, %arg1:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_b
-func.func @set_op_layout_attr_anchor_dpas_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_dpas_b
+func.func @set_anchor_layout_dpas_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
@@ -391,16 +251,16 @@ func.func @set_op_layout_attr_anchor_dpas_b(%arg0: memref<4096x4096xf16>, %arg1:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [16, 16] : !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 1 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [16, 16] : !transform.any_op
transform.yield
}
}
// -----
-// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_c
-func.func @set_op_layout_attr_anchor_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @set_anchor_layout_dpas_c
+func.func @set_anchor_layout_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
@@ -416,8 +276,8 @@ func.func @set_op_layout_attr_anchor_dpas_c(%arg0: memref<4096x4096xf16>, %arg1:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 2 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
+ // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
+ transform.xegpu.set_anchor_layout %0 index = 2 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
@@ -512,8 +372,9 @@ module attributes {transform.with_named_sequence} {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
// CHECK: transform.xegpu.insert_prefetch %{{.*}}
- %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ %3 = transform.xegpu.insert_prefetch %2 nb_prefetch = 1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.canonicalization
} : !transform.any_op
@@ -559,9 +420,10 @@ module attributes {transform.with_named_sequence} {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_load_op %1 : (!transform.any_value) -> !transform.any_op
%nb = transform.param.constant 2 : i64 -> !transform.param<i64>
// CHECK: transform.xegpu.insert_prefetch %{{.*}}
- %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+ %3 = transform.xegpu.insert_prefetch %2 nb_prefetch = %nb : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.canonicalization
} : !transform.any_op
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 346e68eca9201..5d5db1919af14 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -17,7 +17,7 @@ def run(f):
@run
-def getDescOpDefaultIndex():
+def getLoadOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
@@ -25,130 +25,29 @@ def getDescOpDefaultIndex():
)
with InsertionPoint(sequence.body):
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
- desc_handle = xegpu.get_desc_op(operand)
+ load_handle = xegpu.get_load_op(operand)
transform.YieldOp()
- # CHECK-LABEL: TEST: getDescOpDefaultIndex
- # CHECK: transform.xegpu.get_desc_op %
+ # CHECK-LABEL: TEST: getLoadOp
+ # CHECK: transform.xegpu.get_load_op %
@run
-def setDescLayoutMinimal():
+def setAnchorLayout():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.create_nd_tdesc"),
+ transform.OperationType.get("xegpu.load_nd"),
)
with InsertionPoint(sequence.body):
- xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
- transform.YieldOp()
- # CHECK-LABEL: TEST: setDescLayoutMinimal
- # CHECK: %0 = transform.xegpu.set_desc_layout %
- # CHECK: sg_layout = [6, 4]
- # CHECK: sg_data = [32, 16]
-
-
-@run
-def setDescLayoutInstData():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("xegpu.create_nd_tdesc"),
- )
- with InsertionPoint(sequence.body):
- xegpu.set_desc_layout(
- sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
- )
- transform.YieldOp()
- # CHECK-LABEL: TEST: setDescLayoutInstData
- # CHECK: %0 = transform.xegpu.set_desc_layout %
- # CHECK: sg_layout = [6, 4]
- # CHECK: sg_data = [32, 16]
- # CHECK: inst_data = [8, 16]
-
-
-@run
-def setDescLayoutSlice():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("xegpu.create_nd_tdesc"),
- )
- with InsertionPoint(sequence.body):
- xegpu.set_desc_layout(
- sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
- )
- transform.YieldOp()
- # CHECK-LABEL: TEST: setDescLayoutSlice
- # CHECK: %0 = transform.xegpu.set_desc_layout %
- # CHECK: sg_layout = [6, 4]
- # CHECK: sg_data = [32, 16]
- # CHECK: slice_dims = [0]
-
-
-@run
-def setDescLayoutOrder():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("xegpu.create_nd_tdesc"),
- )
- with InsertionPoint(sequence.body):
- xegpu.set_desc_layout(
- sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], order=[0, 1]
- )
- transform.YieldOp()
- # CHECK-LABEL: TEST: setDescLayoutOrder
- # CHECK: %0 = transform.xegpu.set_desc_layout %
- # CHECK: sg_layout = [6, 4]
- # CHECK: sg_data = [32, 16]
- # CHECK: order = [0, 1]
-
-
-@run
-def setOpLayoutAttrOperandMinimal():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("xegpu.dpas"),
- )
- with InsertionPoint(sequence.body):
- xegpu.set_op_layout_attr(
+ xegpu.set_anchor_layout(
sequence.bodyTarget,
sg_layout=[6, 4],
sg_data=[32, 16],
- operand=True,
- )
- transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttr
- # CHECK: transform.xegpu.set_op_layout_attr %
- # CHECK: operand
- # CHECK-NOT: index = 0
- # CHECK-NOT: result
- # CHECK: sg_layout = [6, 4]
- # CHECK: sg_data = [32, 16]
- # CHECK-NOT: inst_data
-
-
-@run
-def setOpLayoutAttrResult():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("xegpu.dpas"),
- )
- with InsertionPoint(sequence.body):
- xegpu.set_op_layout_attr(
- sequence.bodyTarget,
- index=0,
- sg_layout=[6, 4],
- sg_data=[32, 16],
inst_data=[8, 16],
- result=True,
)
transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttrResult
- # CHECK: transform.xegpu.set_op_layout_attr %
- # CHECK: result
+ # CHECK-LABEL: TEST: setAnchorLayout
+ # CHECK: transform.xegpu.set_anchor_layout %
# CHECK-NOT: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
@@ -156,85 +55,77 @@ def setOpLayoutAttrResult():
@run
-def setOpLayoutAttrResultSlice():
+def setAnchorLayoutDPAS():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.dpas"),
)
with InsertionPoint(sequence.body):
- xegpu.set_op_layout_attr(
+ xegpu.set_anchor_layout(
sequence.bodyTarget,
- index=0,
+ index=1,
sg_layout=[6, 4],
sg_data=[32, 16],
inst_data=[8, 16],
- slice_dims=[0],
- result=True,
)
transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
- # CHECK: transform.xegpu.set_op_layout_attr %
- # CHECK: result
- # CHECK-NOT: index = 0
+ # CHECK-LABEL: TEST: setAnchorLayoutDPAS
+ # CHECK: transform.xegpu.set_anchor_layout %
+ # CHECK: index = 1
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
- # CHECK: slice_dims = [0]
@run
-def setOpLayoutAttrResultOrder():
+def setAnchorLayoutOrder():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.dpas"),
+ transform.OperationType.get("xegpu.load_nd"),
)
with InsertionPoint(sequence.body):
- xegpu.set_op_layout_attr(
+ xegpu.set_anchor_layout(
sequence.bodyTarget,
- index=0,
sg_layout=[6, 4],
sg_data=[32, 16],
inst_data=[8, 16],
- order=[0, 1],
- result=True,
+ order=[1, 0],
)
transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttrResultOrder
- # CHECK: transform.xegpu.set_op_layout_attr %
- # CHECK: result
+ # CHECK-LABEL: TEST: setAnchorLayoutOrder
+ # CHECK: transform.xegpu.set_anchor_layout %
# CHECK-NOT: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
- # CHECK: order = [0, 1]
+ # CHECK: order = [1, 0]
@run
-def setOpLayoutAttrAnchor():
+def setAnchorLayoutSlice():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.dpas"),
+ transform.OperationType.get("xegpu.load"),
)
with InsertionPoint(sequence.body):
- xegpu.set_op_layout_attr(
+ xegpu.set_anchor_layout(
sequence.bodyTarget,
- index=0,
sg_layout=[6, 4],
sg_data=[32, 16],
inst_data=[8, 16],
+ slice_dims=[0],
)
transform.YieldOp()
- # CHECK-LABEL: TEST: setOpLayoutAttrAnchor
- # CHECK: transform.xegpu.set_op_layout_attr %
- # CHECK-NOT: result
- # CHECK-NOT: operand
+ # CHECK-LABEL: TEST: setAnchorLayoutSlice
+ # CHECK: transform.xegpu.set_anchor_layout %
# CHECK-NOT: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
+ # CHECK: slice_dims = [0]
@run
@@ -253,21 +144,17 @@ def setGPULaunchThreadsOp():
@run
-def insertPrefetch0():
+def insertPrefetch():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.dpas"),
+ transform.OperationType.get("xegpu.load_nd"),
)
with InsertionPoint(sequence.body):
- operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
- xegpu.insert_prefetch(
- operand,
- )
+ xegpu.insert_prefetch(sequence.bodyTarget)
transform.YieldOp()
- # CHECK-LABEL: TEST: insertPrefetch0
- # CHECK: %[[OPR:.*]] = get_operand
- # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-LABEL: TEST: insertPrefetch
+ # CHECK: transform.xegpu.insert_prefetch
@run
@@ -275,18 +162,13 @@ def insertPrefetchNbPrefetch():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.dpas"),
+ transform.OperationType.get("xegpu.load_nd"),
)
with InsertionPoint(sequence.body):
- operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
- xegpu.insert_prefetch(
- operand,
- nb_prefetch=2,
- )
+ xegpu.insert_prefetch(sequence.bodyTarget, nb_prefetch=2)
transform.YieldOp()
# CHECK-LABEL: TEST: insertPrefetchNbPrefetch
- # CHECK: %[[OPR:.*]] = get_operand
- # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK: transform.xegpu.insert_prefetch
# CHECK-SAME: nb_prefetch = 2
@@ -295,25 +177,20 @@ def insertPrefetchNbPrefetchParam():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("xegpu.dpas"),
+ transform.OperationType.get("xegpu.load_nd"),
)
with InsertionPoint(sequence.body):
- operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
int32_t = IntegerType.get_signless(32)
param_int32_t = transform.ParamType.get(int32_t)
nb_param = transform.ParamConstantOp(
param_int32_t,
IntegerAttr.get(int32_t, 2),
)
- xegpu.insert_prefetch(
- operand,
- nb_prefetch=nb_param,
- )
+ xegpu.insert_prefetch(sequence.bodyTarget, nb_prefetch=nb_param)
transform.YieldOp()
# CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
- # CHECK: %[[OPR:.*]] = get_operand
# CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
- # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK: transform.xegpu.insert_prefetch
# CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
More information about the Mlir-commits
mailing list