[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add insert_prefetch op (PR #167356)

Tuomas Kärnä llvmlistbot at llvm.org
Tue Nov 11 10:30:27 PST 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/167356

From 7dc5a277715a3f817541e91e815907461ece6d1f Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 10:59:51 +0200
Subject: [PATCH 1/2] [mlir][xegpu][transformops] add insert_prefetch op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  43 ++++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 138 ++++++++++++++++++
 mlir/python/mlir/dialects/transform/xegpu.py  |  33 +++++
 .../Dialect/XeGPU/transform-ops-invalid.mlir  |  31 ++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    |  76 ++++++++++
 .../python/dialects/transform_xegpu_ext.py    |  67 ++++++++-
 6 files changed, 387 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index f5e4afad535e5..85ad91f94a379 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -200,4 +200,47 @@ def SetGPULaunchThreadsOp
   }];
 }
 
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+  let description = [{
+    Given a target value (e.g., a `vector`) residing in an `scf.for` loop, this
+    transform finds the corresponding `xegpu.load_nd` op and inserts
+    `xegpu.prefetch_nd` operations for the tile. The load op must reside within
+    the `scf.for` loop. The number of prefetch steps is set by the
+    `nb_prefetch` argument. Returns a handle to the created
+    `xegpu.create_nd_tdesc` op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+                   DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+                   );
+
+  let results = (outs TransformHandleTypeInterface:$desc_op);
+
+  let assemblyFormat = [{
+    $target
+    `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    OpFoldResult getNbPrefetch() {
+      auto ctx = getContext();
+      if (getDynamicNbPrefetch())
+        return OpFoldResult(getDynamicNbPrefetch());
+      return OpFoldResult(
+          IntegerAttr::get(IntegerType::get(ctx, 64), getStaticNbPrefetch()));
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
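For reference, a minimal transform sequence driving the new op (mirroring the
tests added below) could look as follows; it assumes the payload contains an
`xegpu.dpas` op whose A operand is produced by an `xegpu.load_nd` inside an
`scf.for` loop:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
      // Prefetch two steps ahead of the loads.
      %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 2 : (!transform.any_value) -> !transform.any_op
      transform.yield
    }
  }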
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7a7a8c9066f09..230b4aaaa8e8e 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 
@@ -405,6 +406,143 @@ void transform::SetGPULaunchThreadsOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get dynamic prefetch count from transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1) {
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    }
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0) {
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+  }
+
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().size() == 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+  auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+  // Clone desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Clone reduction loop to emit initial prefetches.
+  // Compute upper bound of the init loop: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Modify loadOp mixedOffsets by replacing the for loop induction variable
+  // with the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert prefetch op in init loop.
+  // Replace induction var with the init loop induction var.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint);
+
+  // Insert prefetch op in main loop.
+  // Calculate prefetch offset after the init prefetches have been issued.
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+                                              forOp.getInductionVar(), nbStep);
+  // Replace induction var with correct offset.
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(prefetchOffset), readCacheHint,
+                              readCacheHint, readCacheHint);
+
+  // Unroll the init loop.
+  if (failed(loopUnrollFull(initForOp))) {
+    return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+  }
+
+  results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
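To illustrate the rewrite, with nb_prefetch = 1 the transformed payload has
roughly the shape below. This is a hand-written sketch based on the test
expectations further down; SSA names are illustrative and the cache-hint
attributes (all three levels set to cached) are elided:

  // Clone of the load's descriptor, hoisted above the loop.
  %pdesc = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  // Init loop of nb_prefetch iterations, fully unrolled: warms the cache
  // before the first iteration of the main loop.
  xegpu.prefetch_nd %pdesc[%c0, %c0] : !xegpu.tensor_desc<256x32xf16>
  scf.for %iv = %c0 to %c4096 step %c32 iter_args(%acc = %init) -> (vector<256x256xf16>) {
    // Prefetch nb_prefetch * step ahead of the current load.
    %ahead = arith.addi %iv, %c32 : index
    xegpu.prefetch_nd %pdesc[%c0, %ahead] : !xegpu.tensor_desc<256x32xf16>
    %a = xegpu.load_nd %3[%c0, %iv] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
    ...
  }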
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 309883cfc4518..6443d2a188ec1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
     from .._ods_common import _cext as _ods_cext
     from .._ods_common import (
         MixedValues,
+        MixedInt,
         get_op_result_or_value as _get_op_result_or_value,
         _dispatch_dynamic_index_list,
     )
@@ -134,6 +135,7 @@ def __init__(
         )
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
 class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
     """Specialization for SetGPULaunchThreadsOp class."""
 
@@ -168,3 +170,34 @@ def set_gpu_launch_threads(
     ip=None,
 ) -> SetGPULaunchThreadsOp:
     return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+    """Specialization for InsertPrefetchOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        nb_prefetch: Optional[MixedInt] = 1,
+        loc=None,
+        ip=None,
+    ):
+        static_nb_prefetch = 1
+        dynamic_nb_prefetch = None
+        if isinstance(nb_prefetch, int):
+            static_nb_prefetch = nb_prefetch
+        elif isinstance(nb_prefetch, IntegerAttr):
+            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
+        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+            dynamic_nb_prefetch = nb_prefetch
+
+        super().__init__(
+            transform.AnyOpType.get(),
+            target,
+            dynamic_nb_prefetch=dynamic_nb_prefetch,
+            static_nb_prefetch=static_nb_prefetch,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 24f500658f740..dce4a41982550 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -124,3 +124,34 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  // expected-note@below {{load op}}
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+    // expected-error@below {{Load op is not contained in a scf.for loop.}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 7f2fbe4271a43..aed8874723801 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -308,3 +308,79 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index dc91f5e982579..cfe2281ba5eff 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import AnyValueType
+from mlir.dialects.transform import structured, AnyValueType
 
 
 def run(f):
@@ -128,3 +128,68 @@ def setGPULaunchThreadsOp():
     # CHECK-LABEL: TEST: setGPULaunchThreadsOp
     # CHECK: transform.xegpu.set_gpu_launch_threads
     # CHECK: threads = [8, 4, 1]
+
+
+@run
+def insertPrefetch0():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetch0
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=2,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        int32_t = IntegerType.get_signless(32)
+        param_int32_t = transform.ParamType.get(int32_t)
+        nb_param = transform.ParamConstantOp(
+            param_int32_t,
+            IntegerAttr.get(int32_t, 2),
+        )
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=nb_param,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]

From 85aafcc78ce2b242f2756271c9fd6cc2b784125d Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 19:03:58 +0200
Subject: [PATCH 2/2] address review comments

---
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 18 ++++------
 mlir/python/mlir/dialects/transform/xegpu.py  | 10 ++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 34 ++++++++++++++-----
 .../python/dialects/transform_xegpu_ext.py    |  6 ++--
 4 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 230b4aaaa8e8e..d2235f18ceaec 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -411,11 +411,10 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
   auto targetValues = state.getPayloadValues(getTarget());
-  if (!llvm::hasSingleElement(targetValues)) {
+  if (!llvm::hasSingleElement(targetValues))
     return emitDefiniteFailure()
            << "requires exactly one target value handle (got "
            << llvm::range_size(targetValues) << ")";
-  }
   auto value = *targetValues.begin();
 
   int64_t nbPrefetch = getStaticNbPrefetch();
@@ -426,22 +425,19 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
                                           {getDynamicNbPrefetch()});
     if (!status.succeeded())
       return status;
-    if (dynamicNbPrefetch.size() != 1) {
+    if (dynamicNbPrefetch.size() != 1)
       return emitDefiniteFailure()
              << "requires exactly one value for dynamic_nb_prefetch";
-    }
     nbPrefetch = dynamicNbPrefetch[0];
   }
-  if (nbPrefetch <= 0) {
+  if (nbPrefetch <= 0)
     return emitSilenceableFailure(getLoc())
            << "nb_prefetch must be a positive integer.";
-  }
 
   // Find load operation of the operand.
   auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
-  if (!maybeLoadOp) {
+  if (!maybeLoadOp)
     return emitSilenceableFailure(getLoc()) << "Could not find load op.";
-  }
   auto loadOp = *maybeLoadOp;
   if (loadOp.getMixedOffsets().size() == 0) {
     auto diag = emitSilenceableFailure(getLoc())
@@ -461,9 +457,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
 
   // Find descriptor op.
   auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
-  if (!maybeDescOp) {
+  if (!maybeDescOp)
     return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
-  }
   auto descOp = *maybeDescOp;
   if (descOp.getMixedOffsets().size() > 0) {
     auto diag = emitSilenceableFailure(getLoc())
@@ -526,9 +521,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
                               readCacheHint, readCacheHint);
 
   // Unroll the init loop.
-  if (failed(loopUnrollFull(initForOp))) {
+  if (failed(loopUnrollFull(initForOp)))
     return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
-  }
 
   results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
 
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 6443d2a188ec1..aa3cea58623ea 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -201,3 +201,13 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+def insert_prefetch(
+    target: Value,
+    *,
+    nb_prefetch: Optional[MixedInt] = 1,
+    loc=None,
+    ip=None,
+) -> OpResult:
+    return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index aed8874723801..b3b883826c1c8 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -313,8 +313,10 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: @insert_prefetch_dpas_a
 func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  // CHECK: %[[C32:.+]] = arith.constant 32 : index
   %c32 = arith.constant 32 : index
   %c4096 = arith.constant 4096 : index
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
   %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
@@ -322,12 +324,13 @@ func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<40
   // CHECK: xegpu.create_nd_tdesc %arg1
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
   // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
-  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]]
   %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
-  // CHECK: scf.for
+  // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
   %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
-    // CHECK: xegpu.prefetch_nd %[[V0]]
+    // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]]
+    // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]]
     %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
     %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
     %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
@@ -338,10 +341,15 @@ func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<40
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     // CHECK: transform.xegpu.insert_prefetch %{{.*}}
     %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+
     transform.yield
   }
 }
@@ -350,8 +358,11 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
 func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  // CHECK: %[[C64:.+]] = arith.constant 64 : index
+  // CHECK: %[[C32:.+]] = arith.constant 32 : index
   %c32 = arith.constant 32 : index
   %c4096 = arith.constant 4096 : index
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
   %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
@@ -359,13 +370,14 @@ func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1:
   // CHECK: xegpu.create_nd_tdesc %arg1
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
   // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
-  // CHECK: xegpu.prefetch_nd %[[V0]]
-  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]]
+  // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]]
   %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
-  // CHECK: scf.for
+  // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
   %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
-    // CHECK: xegpu.prefetch_nd %[[V0]]
+    // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]]
+    // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]]
     %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
     %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
     %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
@@ -376,11 +388,15 @@ func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1:
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
     // CHECK: transform.xegpu.insert_prefetch %{{.*}}
     %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index cfe2281ba5eff..56c7d71f28431 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -139,7 +139,7 @@ def insertPrefetch0():
     )
     with InsertionPoint(sequence.body):
         operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
-        xegpu.InsertPrefetchOp(
+        xegpu.insert_prefetch(
             operand,
         )
         transform.YieldOp()
@@ -157,7 +157,7 @@ def insertPrefetchNbPrefetch():
     )
     with InsertionPoint(sequence.body):
         operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
-        xegpu.InsertPrefetchOp(
+        xegpu.insert_prefetch(
             operand,
             nb_prefetch=2,
         )
@@ -183,7 +183,7 @@ def insertPrefetchNbPrefetchParam():
             param_int32_t,
             IntegerAttr.get(int32_t, 2),
         )
-        xegpu.InsertPrefetchOp(
+        xegpu.insert_prefetch(
             operand,
             nb_prefetch=nb_param,
         )


