[Mlir-commits] [mlir] [MLIR][Math][XeVM] Add MathToXeVM (`math-to-xevm`) pass (PR #159878)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Oct 10 04:07:25 PDT 2025


================
@@ -0,0 +1,178 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!(static_cast<uint32_t>(fastFlags) &
+          static_cast<uint32_t>(arith::FastMathFlags::afn)))
+      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+    SmallVector<Type, 1> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      Type opTy = operand.getType();
+      // This pass only supports operations on vectors that are already in SPIRV
+      // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+      // supported vector sizes is done in other blocking optimization passes.
+      if (!isSPIRVCompatibleFloatOrVec(opTy))
+        return rewriter.notifyMatchFailure(
+            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+      operandTypes.push_back(opTy);
+    }
+
+    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+    auto funcOpRes = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+        operandTypes, op.getType());
+    assert(!failed(funcOpRes));
+    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
+    // Preserve fastmath flags in our MLIR op when converting to llvm function
+    // calls, in order to allow further fastmath optimizations: We thus need to
+    // convert arith fastmath attrs into attrs recognized by llvm.
+    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat()) {
+      return true;
+    } else if (auto vecType = dyn_cast<VectorType>(type)) {
----------------
adam-smnk wrote:

nit: the `else` after the `return` could be dropped to reduce nesting

https://github.com/llvm/llvm-project/pull/159878


More information about the Mlir-commits mailing list