[Mlir-commits] [mlir] [MLIR][Conversion] Add convert-xevm-to-llvm pass. (PR #147375)
Sang Ik Lee
llvmlistbot at llvm.org
Wed Jul 9 10:57:15 PDT 2025
================
@@ -0,0 +1,616 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+/// Bundle of function-level attributes that createDeviceFunctionCall applies
+/// to the device function declaration it looks up or creates.
+struct LLVMFuncAttributeOptions {
+  bool isConvergent{false};
+  bool isNoUnwind{false};
+  bool isWillReturn{false};
+  /// Optional memory-effects attribute; it is only set on the function when
+  /// non-null.
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+
+// Preset attribute combinations. None of them carries a memory-effects
+// attribute.
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs{
+    /*isConvergent=*/false, /*isNoUnwind=*/true, /*isWillReturn=*/false,
+    /*memEffectsAttr=*/{}};
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs{
+    /*isConvergent=*/false, /*isNoUnwind=*/true, /*isWillReturn=*/true,
+    /*memEffectsAttr=*/{}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs{
+    /*isConvergent=*/true, /*isNoUnwind=*/true, /*isWillReturn=*/true,
+    /*memEffectsAttr=*/{}};
+
+/// Returns the mangled spelling of `ty`.
+///
+/// Handled types: vectors ("Dv<N>_" followed by the element mangling), f16
+/// ("Dh"), f32 ("f"), f64 ("d"), and 8/16/32/64-bit integers. For integers,
+/// `isUnsigned` selects between the unsigned ("h", "t", "j", "m") and signed
+/// ("c", "s", "i", "l") codes; for vectors it is forwarded to the element
+/// type. Any other type or integer width is treated as unreachable.
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  if (auto vecTy = dyn_cast<VectorType>(ty)) {
+    std::string mangled = "Dv";
+    mangled += std::to_string(vecTy.getNumElements());
+    mangled += "_";
+    mangled += getTypeMangling(vecTy.getElementType(), isUnsigned);
+    return mangled;
+  }
+  if (isa<Float16Type>(ty))
+    return "Dh";
+  if (isa<Float32Type>(ty))
+    return "f";
+  if (isa<Float64Type>(ty))
+    return "d";
+  if (auto intTy = dyn_cast<IntegerType>(ty)) {
+    switch (intTy.getWidth()) {
+    case 8:
+      return isUnsigned ? "h" : "c";
+    case 16:
+      return isUnsigned ? "t" : "s";
+    case 32:
+      return isUnsigned ? "j" : "i";
+    case 64:
+      return isUnsigned ? "m" : "l";
+    default:
+      llvm_unreachable("unhandled integer type");
+    }
+  }
+  llvm_unreachable("unhandled type for mangling");
+}
+
+/// Builds a mangled name of the form `_Z<len><baseName><arg-types...>`.
+///
+/// \param baseName   Unmangled function name.
+/// \param types      Argument types, mangled in order via getTypeMangling.
+/// \param isUnsigned Either empty (all integers are treated as signed) or one
+///                   flag per entry of `types` selecting unsigned integer
+///                   codes for that argument.
+///
+/// A non-scalar argument that repeats an earlier one is emitted as a
+/// substitution (`S_`, `S0_`, `S1_`, ...). Two arguments only share a
+/// substitution when both the type AND the requested signedness match, since
+/// e.g. `Dv8_i` and `Dv8_j` are distinct mangled types even though they map
+/// to the same MLIR vector type.
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  // Substitution candidates in order of first appearance; the position in
+  // this list is the candidate's index. Argument lists are short, so a linear
+  // search is sufficient.
+  SmallVector<std::pair<Type, bool>> candidates;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    const bool typeIsUnsigned = isUnsigned.empty() ? false : isUnsigned[idx];
+    const std::pair<Type, bool> key(type, typeIsUnsigned);
+    auto it = llvm::find(candidates, key);
+    if (it != candidates.end()) {
+      os << "S";
+      // First substitution is `S_`, second is `S0_`, and so on.
+      // NOTE: the index is printed in decimal, which agrees with the Itanium
+      // base-36 <seq-id> only for the first eleven candidates.
+      if (const auto candidateIdx = it - candidates.begin(); candidateIdx > 0)
+        os << candidateIdx - 1;
+      os << "_";
+      continue;
+    }
+    // Scalar builtin types are never substitution candidates.
+    if (!type.isIntOrFloat())
+      candidates.push_back(key);
+    os << getTypeMangling(type, typeIsUnsigned);
+  }
+  return os.str();
+}
+
+/// Translates the L1 part of `op`'s cache-control attribute into an integer
+/// code.
+///
+/// Loads (`isLoad == true`):  L1UC -> 1, L1C -> 2, L1S -> 3,
+///                            INVALIDATE_READ -> 4.
+/// Stores (`isLoad == false`): L1UC -> 1, L1WT -> 2, L1S -> 3, L1WB -> 4.
+/// Any other enumerator yields 0.
+///
+/// The attribute is dereferenced unconditionally, so the caller must ensure
+/// `op.getCacheControl()` holds a value.
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  const auto cacheControl = *op.getCacheControl();
+  if constexpr (isLoad) {
+    switch (cacheControl) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3C:
+      return 1;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1C_L2C_L3C:
+      return 2;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3UC:
+    case LoadCacheControl::L1S_L2C_L3C:
+      return 3;
+    case LoadCacheControl::INVALIDATE_READ:
+      return 4;
+    default:
+      return 0;
+    }
+  } else {
+    switch (cacheControl) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      return 1;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      return 2;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      return 3;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      return 4;
+    default:
+      return 0;
+    }
+  }
+}
+
+/// Translates the L3 part of `op`'s cache-control attribute into an integer
+/// code.
+///
+/// Loads (`isLoad == true`):  L3UC -> 1, L3C -> 2, INVALIDATE_READ -> 4.
+/// Stores (`isLoad == false`): L3UC -> 1, L3WB -> 2.
+/// Any other enumerator yields 0.
+///
+/// The attribute is dereferenced unconditionally, so the caller must ensure
+/// `op.getCacheControl()` holds a value.
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  const auto cacheControl = *op.getCacheControl();
+  if constexpr (isLoad) {
+    switch (cacheControl) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2C_L3UC:
+      return 1;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3C:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3C:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3C:
+      return 2;
+    case LoadCacheControl::INVALIDATE_READ:
+      return 4;
+    default:
+      return 0;
+    }
+  } else {
+    switch (cacheControl) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+      return 1;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      return 2;
+    default:
+      return 0;
+    }
+  }
+}
+
+/// Builds the cache-control metadata for `op`, or returns std::nullopt when
+/// the op carries no cache-control attribute.
+///
+/// The result is an array of two i32 arrays, each of the form
+/// {key, level, control, 0}: the first uses level 0 with the value from
+/// getL1CacheControl, the second uses level 1 with the value from
+/// getL3CacheControl. `key` is 6442 for loads and 6443 for stores
+/// (presumably the SPIR-V cache-control decoration ids — TODO confirm).
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+  if (!op.getCacheControl())
+    return std::nullopt;
+
+  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t loadCacheControlKey{6442};
+  constexpr int32_t storeCacheControlKey{6443};
+  constexpr int32_t controlKey =
+      isLoad ? loadCacheControlKey : storeCacheControlKey;
+
+  const int32_t l1Decoration[decorationCacheControlArity] = {
+      controlKey, 0, getL1CacheControl<isLoad>(op), 0};
+  const int32_t l3Decoration[decorationCacheControlArity] = {
+      controlKey, 1, getL3CacheControl<isLoad>(op), 0};
+
+  const Attribute perLevelAttrs[] = {rewriter.getI32ArrayAttr(l1Decoration),
+                                     rewriter.getI32ArrayAttr(l3Decoration)};
+  return rewriter.getArrayAttr(perLevelAttrs);
+}
+
+/// Emits a call to the device function `funcName`.
+///
+/// The callee is obtained with LLVM::lookupOrCreateFn in the closest ancestor
+/// of `op` that has the SymbolTable trait. It is then given the SPIR_FUNC
+/// calling convention, the convergent/nounwind/willreturn flags from
+/// `funcAttributeOptions`, its memory-effects attribute when one is provided,
+/// and a unit attribute for every (argument index, attribute name) pair in
+/// `paramAttrs`.
+///
+/// \param rewriter Rewriter used to create the function and the call.
+/// \param funcName Symbol name of the callee.
+/// \param retType  Result type of the callee.
+/// \param argTypes Parameter types of the callee.
+/// \param args     Call operands.
+/// \param op       Operation providing the location and the enclosing
+///                 symbol table.
+/// \returns the created call, whose attributes are set to the callee's
+///          attribute list.
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args,
+    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
+  Operation *symbolTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(symbolTableOp && "Expecting module");
+
+  auto maybeCallee = LLVM::lookupOrCreateFn(rewriter, symbolTableOp, funcName,
+                                            argTypes, retType);
+  assert(succeeded(maybeCallee));
+  LLVM::LLVMFuncOp callee = *maybeCallee;
+
+  // Configure the declaration.
+  callee.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  callee.setConvergent(funcAttributeOptions.isConvergent);
+  callee.setNoUnwind(funcAttributeOptions.isNoUnwind);
+  callee.setWillReturn(funcAttributeOptions.isWillReturn);
+  if (funcAttributeOptions.memEffectsAttr)
+    callee.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
+
+  for (const auto &[argIdx, attrName] : paramAttrs)
+    callee.setArgAttr(argIdx, attrName, rewriter.getUnitAttr());
+
+  // Create the call and set its attributes to the callee's attribute list.
+  auto call = rewriter.create<LLVM::CallOp>(op->getLoc(), callee, args);
+  call->setAttrs(callee->getAttrs());
+  return call;
+}
+
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!op.getC()) {
+ return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+ }
+ constexpr uint32_t bitWidthPackedA{16};
+ constexpr uint32_t bitWidthPackedB{32};
+ auto loc = op.getLoc();
+
+ auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+ VectorType origTy = cast<VectorType>(val.getType());
+ const uint32_t vecBitSize =
+ origTy.getNumElements() *
+ origTy.getElementType().getIntOrFloatBitWidth();
+ VectorType newTy = VectorType::get(
+ vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+ if (origTy != newTy)
+ val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+ return val;
+ };
+
+ Value a = op.getA();
+ Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+ ? cast<Type>(rewriter.getF32Type())
+ : rewriter.getIntegerType(bitWidthPackedA);
----------------
silee2 wrote:
OpenCL intrinsics are defined that way.
When `A` and `B` have types other than `tf32`, they are passed as integers, as shown here:
https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
But if `A` and `B` are `tf32`, they are passed as `f32` (`float` in the C language).
An example of `tf32` usage in the Intel Graphics Compiler is here:
https://github.com/intel/intel-graphics-compiler/blob/master/IGC/ocloc_tests/Builtins/cl_intel_subgroup_matrix_multiply_accumulate_tf32/dpas.ll
https://github.com/llvm/llvm-project/pull/147375
More information about the Mlir-commits
mailing list