[Mlir-commits] [mlir] [MLIR][XeGPU] add xegpu.set_desc_layout transform op (PR #165615)

Tuomas Kärnä llvmlistbot at llvm.org
Tue Nov 4 12:59:48 PST 2025


================
@@ -0,0 +1,238 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <numeric>
+
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<gpu::GPUDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attribute case.
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        result.push_back(intAttr.getInt());
+        continue;
+      }
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+    }
+
+    // Transform param case.
+    Value transformValue = cast<Value>(ofr);
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    // Payload value case.
+    auto payloadOps = state.getPayloadOps(transformValue);
+    if (!llvm::hasSingleElement(payloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(transformValue.getLoc())
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+      return diag;
+    }
+
+    Operation *op = *payloadOps.begin();
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+
+    IntegerAttr intAttr;
+    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+      return transformOp.emitSilenceableError()
+             << "requires param or handle to be the result of a constant like "
+                "op";
+
+    result.push_back(intAttr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                                   std::optional<ArrayRef<int32_t>> sgData,
+                                   std::optional<ArrayRef<int32_t>> instData) {
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+      sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
+      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
+                                    xegpu::CreateNdDescOp descOp,
+                                    xegpu::LayoutAttr layout) {
+  auto oldTensorDesc = descOp.getResult();
+  auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
+  auto descType = xegpu::TensorDescType::get(
+      descShapedType.getShape(), descShapedType.getElementType(),
+      /*array_length=*/1,
+      /*boundary_check=*/true,
+      /*memory_space=*/xegpu::MemorySpace::Global,
----------------
tkarna wrote:

Updated, BlockTensorDescAttrs are preserved now. This is now tested in the first test.

https://github.com/llvm/llvm-project/pull/165615


More information about the Mlir-commits mailing list