[Mlir-commits] [mlir] [MLIR][Dialect] Add XeVM dialect (PR #144811)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jul 1 06:39:19 PDT 2025


================
@@ -0,0 +1,377 @@
+//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc>
+#include <mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc>
+
+namespace {
+constexpr uint32_t subgroupSize = 16; // Subgroup size, in elements.
+
+/// Checks the constraints shared by the 2D block load, store, and prefetch
+/// ops: base pitch versus base width (only when both constant values are
+/// available), element size, tile height, and v_blocks.
+template <typename Op>
+LogicalResult verifyMatrixInput(Op op) {
+  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+                                BlockPrefetch2dOp>::value,
+                "Unexpected template parameter");
+
+  // True when `value` is a power of two that does not exceed `limit`.
+  auto isPowerOf2UpTo = [](uint32_t value, uint32_t limit) {
+    return value <= limit && llvm::isPowerOf2_32(value);
+  };
+
+  const std::optional<int64_t> baseWidth =
+      getConstantIntValue(op.getBaseWidth());
+  const std::optional<int64_t> basePitch =
+      getConstantIntValue(op.getBasePitch());
+  const bool bothKnown = basePitch.has_value() && baseWidth.has_value();
+  if (bothKnown && *basePitch < *baseWidth)
+    return op->emitOpError(
+        "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+  const uint32_t elemBits = op.getElemSizeInBits();
+  if (elemBits < 8 || !isPowerOf2UpTo(elemBits, 32))
+    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+  if (!isPowerOf2UpTo(op.getTileHeight(), 32))
+    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+  if (!isPowerOf2UpTo(op.getVBlocks(), 8))
+    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+  return success();
+}
+
+/// Validates a 2D block load: the result vector must have the expected total
+/// bit size, `transpose` and `pack_register` must not both be set, and the
+/// tile shape / v_blocks must respect the limits that apply to the selected
+/// combination of those flags and the element size.
+LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
+  VectorType resultTy = op.getRes().getType();
+  Type resultElemTy = resultTy.getElementType();
+  if (!resultElemTy.isIntOrFloat())
+    return op.emitOpError()
+           << "expecting result element type to be int or float";
+
+  const uint32_t elemBits = op.getElemSizeInBits();
+  const uint32_t tileHeight = op.getTileHeight();
+  const uint32_t tileWidth = op.getTileWidth();
+  const uint32_t vBlocks = op.getVBlocks();
+
+  // Compare the result vector's bit size with the tile's bits divided by the
+  // subgroup size.
+  const unsigned resultBits =
+      resultTy.getNumElements() * resultElemTy.getIntOrFloatBitWidth();
+  const unsigned expectedBits =
+      elemBits * tileHeight * tileWidth * vBlocks / subgroupSize;
+  if (resultBits != expectedBits)
+    return op.emitOpError() << "result size of " << resultBits
+                            << " bits does not match the expected size of "
+                            << expectedBits << " bits";
+
+  const bool isTransposed = op.getTranspose();
+  const bool isPacked = op.getPackRegister();
+  if (isTransposed && isPacked)
+    return op.emitOpError("transpose and pack_register are mutually exclusive");
+
+  // Small predicates shared by the per-mode checks below.
+  auto outside = [](uint32_t value, uint32_t lo, uint32_t hi) {
+    return value < lo || value > hi;
+  };
+  auto isOneTwoOrFour = [](uint32_t value) {
+    return value == 1 || value == 2 || value == 4;
+  };
+
+  // Transposed load.
+  if (isTransposed) {
+    assert(!isPacked && "Expecting pack_register should be false");
+
+    if (vBlocks != 1)
+      return op.emitOpError("expecting v_blocks to be 1");
+
+    switch (elemBits) {
+    case 32:
+      if (outside(tileHeight, 1, 32))
+        return op.emitOpError("expecting tile_height to be between 1 and 32");
+      if (outside(tileWidth, 1, 8))
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      break;
+    case 64:
+      if (tileHeight != 8)
+        return op.emitOpError(
+            "expecting tile_height to be 8 for 64 bit elements");
+      if (!isOneTwoOrFour(tileWidth))
+        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
+      break;
+    default:
+      return op.emitOpError("transpose is only supported for 32 and 64 bit "
+                            "elements");
+    }
+
+    return success();
+  }
+
+  // Register-packed load.
+  if (isPacked) {
+    assert(isPacked && !isTransposed &&
+           "Expecting pack_register should be true and transpose should be "
+           "false");
+
+    if (!isOneTwoOrFour(vBlocks))
+      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+
+    switch (elemBits) {
+    case 8:
+      if (outside(tileHeight, 4, 32))
+        return op.emitOpError("expecting tile_height to be between 4 and 32");
+      if (outside(tileWidth, 4, 16))
+        return op.emitOpError("expecting tile_width to be between 4 and 16");
+      break;
+    case 16:
+      if (outside(tileHeight, 2, 32))
+        return op.emitOpError("expecting tile_height to be between 2 and 32");
+      if (outside(tileWidth, 2, 16))
+        return op.emitOpError("expecting tile_width to be between 2 and 16");
+      if (tileWidth * vBlocks > 32)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 32 for 16 bit elements");
+      break;
+    default:
+      return op.emitOpError("pack_register is only supported for 8 and 16 bit "
+                            "elements");
+    }
+
+    return success();
+  }
+
+  // Neither transpose nor pack_register is set.
+  if (outside(tileHeight, 1, 32))
+    return op.emitOpError("expecting tile_height to be between 1 and 32");
+
+  switch (elemBits) {
+  case 8:
+    if (outside(tileWidth, 4, 64))
+      return op.emitOpError("expecting tile_width to be between 4 and 64");
+    if (!isOneTwoOrFour(vBlocks))
+      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+    if (tileWidth * vBlocks > 64)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 64 for 8 bit elements");
+    break;
+  case 16:
+    if (outside(tileWidth, 2, 32))
+      return op.emitOpError("expecting tile_width to be between 2 and 32");
+    if (!isOneTwoOrFour(vBlocks))
+      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+    if (tileWidth * vBlocks > 32)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 32 for 16 bit elements");
+    break;
+  case 32:
+    if (outside(tileWidth, 1, 16))
+      return op.emitOpError("expecting tile_width to be between 1 and 16");
+    if (!(vBlocks == 1 || vBlocks == 2))
+      return op.emitOpError("expecting v_blocks to be 1 or 2");
+    if (tileWidth * vBlocks > 16)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 16 for 32 bit elements");
+    break;
+  case 64:
+    if (outside(tileWidth, 1, 8))
+      return op.emitOpError("expecting tile_width to be between 1 and 8");
+    if (vBlocks != 1)
+      return op.emitOpError("expecting v_blocks to be 1");
+    break;
+  default:
+    return op.emitOpError(
+        "expecting elem_size_in_bits to be 8, 16, 32, or 64");
+  }
+
+  return success();
+}
+
+/// Validates the tile shape of a 2D block store: tile height must lie in
+/// [1, 8], tile width must lie in the range allowed for the element size, and
+/// v_blocks must be 1.
+static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
+  auto outside = [](uint32_t value, uint32_t lo, uint32_t hi) {
+    return value < lo || value > hi;
+  };
+
+  if (outside(op.getTileHeight(), 1, 8))
+    return op.emitOpError("expecting tile_height to be between 1 and 8");
+
+  // The permitted tile width range depends on the element size.
+  const uint32_t elemBits = op.getElemSizeInBits();
+  const uint32_t width = op.getTileWidth();
+  if (elemBits == 8) {
+    if (outside(width, 4, 64))
+      return op.emitOpError("expecting tile_width to be between 4 and 64");
+  } else if (elemBits == 16) {
+    if (outside(width, 2, 32))
+      return op.emitOpError("expecting tile_width to be between 2 and 32");
+  } else if (elemBits == 32) {
+    if (outside(width, 1, 16))
+      return op.emitOpError("expecting tile_width to be between 1 and 16");
+  } else if (elemBits == 64) {
+    if (outside(width, 1, 8))
+      return op.emitOpError("expecting tile_width to be between 1 and 8");
+  } else {
+    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
+  }
+
+  if (op.getVBlocks() != 1)
+    return op.emitOpError("expecting v_blocks to be 1");
+  return success();
+}
+
+} // namespace
+
+/// Verifier for the 2D block load op. Runs the load-specific restrictions and
+/// the checks shared by all 2D block ops, then requires a 32-bit result
+/// element type for 32-bit elements or packed loads, and a tile width of 16
+/// when `pack_register` is set.
+LogicalResult BlockLoad2dOp::verify() {
+  if (verify2DBlockLoadRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  VectorType resTy = getRes().getType();
+  // Guard the bit-width query below; the wording matches the same diagnostic
+  // emitted by verify2DBlockLoadRestriction.
+  if (!resTy.getElementType().isIntOrFloat())
+    return emitOpError() << "expecting result element type to be int or float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  if ((getElemSizeInBits() == 32 || getPackRegister()) && resElemTySize != 32)
+    return emitOpError() << "expecting result element type to be 32 bits";
+
+  if (getPackRegister() && getTileWidth() != 16)
+    return emitOpError(
+        "tile_width when pack_register is true should be equal "
+        "to subgroup size (16 elements)");
+
+  return success();
+}
+
+/// Verifier for the 2D block store op. After the store-specific and shared
+/// 2D block checks pass, the tile width must be one of the values accepted
+/// for the element size.
+LogicalResult BlockStore2dOp::verify() {
+  if (verify2DBlockStoreRestriction(*this).failed() ||
+      verifyMatrixInput(*this).failed())
+    return failure();
+
+  const uint32_t elemBits = getElemSizeInBits();
+  const uint32_t width = getTileWidth();
+
+  if (elemBits == 8) {
+    if (!(width == 16 || width == 32))
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+    return success();
+  }
+
+  if (elemBits == 16) {
+    if (width != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+    return success();
+  }
+
+  if (elemBits == 32) {
+    if (width != 16)
+      return emitOpError("tile_width for 32 bit elements should be equal "
+                         "to 16");
+    return success();
+  }
+
+  // verifyMatrixInput only lets 8, 16, and 32 bit elements through.
+  llvm_unreachable("unexpected element size");
+}
+
+/// Verifier for the 2D block prefetch op. After the shared 2D block checks
+/// pass, the tile width must be one of the values accepted for the element
+/// size.
+LogicalResult BlockPrefetch2dOp::verify() {
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  const uint32_t elemBits = getElemSizeInBits();
+  const uint32_t width = getTileWidth();
+
+  if (elemBits == 8) {
+    if (!(width == 16 || width == 32))
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+  } else if (elemBits == 16) {
+    if (width != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+  } else if (elemBits == 32) {
+    if (!(width == 8 || width == 16))
+      return emitOpError(
+          "tile_width for 32 bit elements should be equal to 8 or 16");
+  } else {
+    // verifyMatrixInput only lets 8, 16, and 32 bit elements through.
+    llvm_unreachable("unexpected element size");
+  }
+
+  return success();
+}
+
+/// Verifier for the MMA op: when the C operand is present, its type must be
+/// the same as the result type.
+LogicalResult MMAOp::verify() {
+  auto acc = getC();
+  if (acc && getResult().getType() != acc.getType())
+    return emitOpError("type of C operand must match result type");
+  return success();
+}
+
+LogicalResult PrefetchOp::verify() {
+  auto ptrTy = mlir::dyn_cast<LLVM::LLVMPointerType>(getOperand().getType());
+  auto addrSpace = ptrTy.getAddressSpace();
+  if (addrSpace != 1 && addrSpace != 4)
+    return emitOpError(
+        "LLVM pointer type address space must be 1 (global) or 4 (generic)");
----------------
adam-smnk wrote:

This constraint could live directly in the .td file.

https://github.com/llvm/llvm-project/pull/144811


More information about the Mlir-commits mailing list