[Mlir-commits] [mlir] [MLIR][Conversion] Add convert-xevm-to-llvm pass. (PR #147375)

Sang Ik Lee llvmlistbot at llvm.org
Wed Jul 9 10:57:15 PDT 2025


================
@@ -0,0 +1,616 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+/// Function-level flags applied to the declarations that
+/// createDeviceFunctionCall creates for device built-ins.
+struct LLVMFuncAttributeOptions {
+  /// Value passed to LLVMFuncOp::setConvergent.
+  bool isConvergent = false;
+  /// Value passed to LLVMFuncOp::setNoUnwind.
+  bool isNoUnwind = false;
+  /// Value passed to LLVMFuncOp::setWillReturn.
+  bool isWillReturn = false;
+  /// Memory-effects attribute to attach; a null attribute means "don't set".
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+
+// Preset combinations of the flags above.
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+    /*isConvergent=*/false, /*isNoUnwind=*/true, /*isWillReturn=*/false, {}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+    /*isConvergent=*/false, /*isNoUnwind=*/true, /*isWillReturn=*/true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    /*isConvergent=*/true, /*isNoUnwind=*/true, /*isWillReturn=*/true, {}};
+
+/// Returns the Itanium-style mangling of a single parameter type:
+///   vector<N x T>  -> "Dv<N>_" followed by the mangling of T
+///   f16 / f32 / f64 -> "Dh" / "f" / "d"
+///   i8 / i16 / i32 / i64 -> "c" / "s" / "i" / "l", or, when `isUnsigned`
+///   is set, "h" / "t" / "j" / "m".
+/// `isUnsigned` is forwarded to a vector's element type. Any other type
+/// (or integer width) hits llvm_unreachable.
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  if (auto vecTy = dyn_cast<VectorType>(ty)) {
+    std::string elemMangling =
+        getTypeMangling(vecTy.getElementType(), isUnsigned);
+    return "Dv" + std::to_string(vecTy.getNumElements()) + "_" +
+           elemMangling;
+  }
+  if (isa<Float16Type>(ty))
+    return "Dh";
+  if (isa<Float32Type>(ty))
+    return "f";
+  if (isa<Float64Type>(ty))
+    return "d";
+  if (auto intTy = dyn_cast<IntegerType>(ty)) {
+    switch (intTy.getWidth()) {
+    case 8:
+      return isUnsigned ? "h" : "c";
+    case 16:
+      return isUnsigned ? "t" : "s";
+    case 32:
+      return isUnsigned ? "j" : "i";
+    case 64:
+      return isUnsigned ? "m" : "l";
+    default:
+      llvm_unreachable("unhandled integer type");
+    }
+  }
+  llvm_unreachable("unhandled type for mangling");
+}
+
+/// Returns the Itanium-style mangled name `_Z<len><baseName><params>` of a
+/// built-in called `baseName` whose parameter types are `types`.
+///
+/// `isUnsigned` is either empty (every parameter uses the signed spelling) or
+/// has exactly one entry per type, selecting the unsigned spelling for that
+/// parameter.
+///
+/// A repeated non-builtin parameter (here: a vector) is emitted as a
+/// substitution `S_`, `S0_`, `S1_`, ... . Candidates are keyed on their
+/// mangled spelling rather than on the MLIR type. MLIR integers are signless,
+/// so `vector<8xi32>` mangles to `Dv8_i` or `Dv8_j` depending on
+/// `isUnsigned`, and those two must not stand in for one another.
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  // Mangled spellings of the substitution candidates seen so far, in order of
+  // first appearance; an entry's position is its substitution index.
+  SmallVector<std::string, 4> substitutions;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    std::string typeMangling =
+        getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
+    auto it = llvm::find(substitutions, typeMangling);
+    if (it != substitutions.end()) {
+      // First substitution is `S_`, second is `S0_`, and so on.
+      const auto substIdx =
+          static_cast<unsigned>(it - substitutions.begin());
+      os << "S";
+      if (substIdx > 0)
+        os << substIdx - 1;
+      os << "_";
+      continue;
+    }
+    // Builtin scalar types are never substitution candidates.
+    if (!type.isIntOrFloat())
+      substitutions.push_back(typeMangling);
+    os << typeMangling;
+  }
+  return os.str();
+}
+
+/// Maps the op's cache-control enumerator to the integer control value used
+/// for the L1 entry of the cache-control metadata (see the L1 prefix of each
+/// enumerator name).
+///   Loads:  L1UC_* -> 1, L1C_* -> 2, L1S_* -> 3, INVALIDATE_READ -> 4.
+///   Stores: L1UC_* -> 1, L1WT_* -> 2, L1S_* -> 3, L1WB_* -> 4.
+/// Any enumerator not listed yields 0.
+/// `op.getCacheControl()` is dereferenced unconditionally, so callers must
+/// ensure it is set.
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3C:
+      control = 1;
+      break;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1C_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3UC:
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 3;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      // Unlisted enumerators keep control at 0.
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      control = 1;
+      break;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      control = 2;
+      break;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      control = 3;
+      break;
+    // NOTE(review): only three L1WB_* enumerators are listed here (no
+    // L1WB_L2WB_L3WB); confirm against the StoreCacheControl definition.
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 4;
+      break;
+    default:
+      // Unlisted enumerators keep control at 0.
+      break;
+    }
+  }
+  return control;
+}
+
+/// Maps the op's cache-control enumerator to the integer control value used
+/// for the L3 entry of the cache-control metadata (see the L3 suffix of each
+/// enumerator name).
+///   Loads:  *_L3UC -> 1, *_L3C -> 2, INVALIDATE_READ -> 4 (3 is never used).
+///   Stores: *_L3UC -> 1, *_L3WB -> 2.
+/// Any enumerator not listed yields 0.
+/// `op.getCacheControl()` is dereferenced unconditionally, so callers must
+/// ensure it is set.
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  int32_t control = 0;
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2C_L3UC:
+      control = 1;
+      break;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3C:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3C:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3C:
+      control = 2;
+      break;
+    case LoadCacheControl::INVALIDATE_READ:
+      control = 4;
+      break;
+    default:
+      // Unlisted enumerators keep control at 0.
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+      control = 1;
+      break;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      control = 2;
+      break;
+    default:
+      // Unlisted enumerators keep control at 0.
+      break;
+    }
+  }
+  return control;
+}
+
+/// Builds the cache-control metadata for a load (`isLoad`) or store op.
+///
+/// Returns std::nullopt when the op has no cache-control attribute.
+/// Otherwise it returns an ArrayAttr holding two 4-element i32 arrays,
+/// `{key, index, control, 0}`:
+///   - The first array has index 0 and the getL1CacheControl value.
+///   - The second array has index 1 and the getL3CacheControl value.
+///
+/// `key` is 6442 for loads and 6443 for stores. These presumably match the
+/// SPV_INTEL_cache_controls load/store decoration ids; confirm with the
+/// metadata consumer.
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+  if (!op.getCacheControl())
+    return std::nullopt;
+
+  constexpr int32_t kLoadKey = 6442;
+  constexpr int32_t kStoreKey = 6443;
+  const int32_t key = isLoad ? kLoadKey : kStoreKey;
+
+  // One decoration entry: {key, index, control, 0}.
+  auto makeEntry = [&](int32_t index, int32_t control) -> Attribute {
+    SmallVector<int32_t, 4> fields{key, index, control, 0};
+    return rewriter.getI32ArrayAttr(fields);
+  };
+
+  Attribute l1Entry = makeEntry(0, getL1CacheControl<isLoad, OpType>(op));
+  Attribute l3Entry = makeEntry(1, getL3CacheControl<isLoad, OpType>(op));
+  return rewriter.getArrayAttr({l1Entry, l3Entry});
+}
+
+/// Emits an `llvm.call` to the device function `funcName` at `op`.
+///
+/// The callee is looked up, or declared with `argTypes` -> `retType`, in the
+/// nearest enclosing op that has the SymbolTable trait. The declaration then
+/// gets:
+///   - the SPIR_FUNC calling convention;
+///   - the convergent / nounwind / willreturn flags from
+///     `funcAttributeOptions`, plus its memory effects when non-null;
+///   - a unit attribute named `name` on argument `index` for each
+///     (index, name) pair in `paramAttrs`.
+/// The call is built from `args`, and the callee's full attribute dictionary
+/// is copied onto it.
+/// NOTE(review): that copy includes every callee attribute, not only the
+/// call-relevant ones. Confirm this is intended.
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args,
+    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
+  Operation *symbolTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(symbolTableOp && "Expecting module");
+
+  auto declOrFailure = LLVM::lookupOrCreateFn(rewriter, symbolTableOp,
+                                              funcName, argTypes, retType);
+  assert(succeeded(declOrFailure));
+  LLVM::LLVMFuncOp callee = declOrFailure.value();
+
+  // Function-level attributes on the declaration.
+  callee.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  callee.setConvergent(funcAttributeOptions.isConvergent);
+  callee.setNoUnwind(funcAttributeOptions.isNoUnwind);
+  callee.setWillReturn(funcAttributeOptions.isWillReturn);
+  if (LLVM::MemoryEffectsAttr memEffects = funcAttributeOptions.memEffectsAttr)
+    callee.setMemoryEffectsAttr(memEffects);
+
+  // Per-parameter unit attributes.
+  for (const auto &[argIndex, attrName] : paramAttrs)
+    callee.setArgAttr(argIndex, attrName, rewriter.getUnitAttr());
+
+  LLVM::CallOp call =
+      rewriter.create<LLVM::CallOp>(op->getLoc(), callee, args);
+  call->setAttrs(callee->getAttrs());
+  return call;
+}
+
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getC()) {
+      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+    }
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+      return val;
+    };
+
+    Value a = op.getA();
+    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
----------------
silee2 wrote:

This is how the OpenCL built-ins are declared.
For `A` and `B` element types other than `tf32`, the operands are passed as (packed) integers, as shown here:
 https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
If `A` and `B` are `tf32`, however, they are passed as `f32` (`float` in OpenCL C).
For an example of `tf32` usage in the Intel Graphics Compiler, see:
https://github.com/intel/intel-graphics-compiler/blob/master/IGC/ocloc_tests/Builtins/cl_intel_subgroup_matrix_multiply_accumulate_tf32/dpas.ll

https://github.com/llvm/llvm-project/pull/147375


More information about the Mlir-commits mailing list