[Mlir-commits] [mlir] Revert "[MLIR][Math][XeVM] Add MathToXeVM (`math-to-xevm`) pass" (PR #162923)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Fri Oct 10 14:17:11 PDT 2025


https://github.com/mshahneo created https://github.com/llvm/llvm-project/pull/162923

Reverts llvm/llvm-project#159878

>From 5741e129230decf0a48971c06548fae5a8caeaa6 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <98356296+mshahneo at users.noreply.github.com>
Date: Fri, 10 Oct 2025 16:16:43 -0500
Subject: [PATCH] Revert "[MLIR][Math][XeVM] Add MathToXeVM (`math-to-xevm`)
 pass (#159878)"

This reverts commit fabd1c418a6b20266cf191b5d8c92476567c77af.
---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  27 ---
 mlir/include/mlir/Conversion/Passes.h         |   1 -
 mlir/include/mlir/Conversion/Passes.td        |  25 ---
 mlir/lib/Conversion/CMakeLists.txt            |   1 -
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt |  19 --
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 174 ------------------
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 155 ----------------
 .../MathToXeVM/native-spirv-builtins.mlir     | 118 ------------
 8 files changed, 520 deletions(-)
 delete mode 100644 mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
 delete mode 100644 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
 delete mode 100644 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
 delete mode 100644 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
 delete mode 100644 mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
deleted file mode 100644
index 91d3c92fd6296..0000000000000
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
-#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
-
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
-#include "mlir/IR/PatternMatch.h"
-#include <memory>
-
-namespace mlir {
-class Pass;
-
-#define GEN_PASS_DECL_CONVERTMATHTOXEVM
-#include "mlir/Conversion/Passes.h.inc"
-
-/// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
-                                          bool convertArith);
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,7 +49,6 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 25e9d34f3e653..3c18ecc753d0f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,31 +796,6 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
-//===----------------------------------------------------------------------===//
-// MathToXeVM
-//===----------------------------------------------------------------------===//
-
-def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary =
-      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
-  let description = [{
-    This pass converts supported math ops marked with the `afn` fastmath flag
-    to function calls for OpenCL `native_` math intrinsics: These intrinsics
-    are typically mapped directly to native device instructions, often resulting
-    in better performance. However, the precision/error of these intrinsics
-    are implementation-defined, and thus math ops are only converted when they
-    have the `afn` fastmath flag enabled.
-  }];
-  let options = [Option<
-      "convertArith", "convert-arith", "bool", /*default=*/"true",
-      "Convert supported Arith ops (e.g. arith.divf) as well.">];
-  let dependentDialects = [
-    "arith::ArithDialect",
-    "xevm::XeVMDialect",
-    "LLVM::LLVMDialect",
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // MathToEmitC
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..71986f83c4870 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,7 +40,6 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
-add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
deleted file mode 100644
index 95aaba31a993e..0000000000000
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRMathToXeVM
-  MathToXeVM.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRMathDialect
-  MLIRLLVMCommonConversion
-  MLIRPass
-  MLIRTransformUtils
-  MLIRVectorDialect
-  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
deleted file mode 100644
index 03053dee5af40..0000000000000
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ /dev/null
@@ -1,174 +0,0 @@
-//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
-#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMATHTOXEVM
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-
-#define DEBUG_TYPE "math-to-xevm"
-
-/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
-template <typename Op>
-struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
-
-  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
-                           PatternBenefit benefit = 1)
-      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
-
-  LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
-      return failure();
-
-    arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
-      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
-
-    SmallVector<Type, 1> operandTypes;
-    for (auto operand : adaptor.getOperands()) {
-      Type opTy = operand.getType();
-      // This pass only supports operations on vectors that are already in SPIRV
-      // supported vector sizes: Distributing unsupported vector sizes to SPIRV
-      // supported vector sizes are done in other blocking optimization passes.
-      if (!isSPIRVCompatibleFloatOrVec(opTy))
-        return rewriter.notifyMatchFailure(
-            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
-      operandTypes.push_back(opTy);
-    }
-
-    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
-    auto funcOpRes = LLVM::lookupOrCreateFn(
-        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
-        operandTypes, op.getType());
-    assert(!failed(funcOpRes));
-    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
-
-    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, funcOp, adaptor.getOperands());
-    // Preserve fastmath flags in our MLIR op when converting to llvm function
-    // calls, in order to allow further fastmath optimizations: We thus need to
-    // convert arith fastmath attrs into attrs recognized by llvm.
-    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
-    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
-    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
-    return success();
-  }
-
-  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
-    if (type.isFloat())
-      return true;
-    if (auto vecType = dyn_cast<VectorType>(type)) {
-      if (!vecType.getElementType().isFloat())
-        return false;
-      // SPIRV distinguishes between vectors and matrices: OpenCL native math
-      // intrsinics are not compatible with matrices.
-      ArrayRef<int64_t> shape = vecType.getShape();
-      if (shape.size() != 1)
-        return false;
-      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
-      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
-          shape[0] == 16)
-        return true;
-    }
-    return false;
-  }
-
-  inline std::string
-  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
-    std::string mangledFuncName =
-        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
-
-    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
-      if (type.isF32())
-        mangledFuncName += "f";
-      else if (type.isF16())
-        mangledFuncName += "Dh";
-      else if (type.isF64())
-        mangledFuncName += "d";
-    };
-
-    for (auto type : operandTypes) {
-      if (auto vecType = dyn_cast<VectorType>(type)) {
-        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
-        appendFloatToMangledFunc(vecType.getElementType());
-      } else
-        appendFloatToMangledFunc(type);
-    }
-
-    return mangledFuncName;
-  }
-
-  const StringRef nativeFunc;
-};
-
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
-                                                bool convertArith) {
-  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_exp");
-  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_cos");
-  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
-      patterns.getContext(), "__spirv_ocl_native_exp2");
-  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_log");
-  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
-      patterns.getContext(), "__spirv_ocl_native_log2");
-  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
-      patterns.getContext(), "__spirv_ocl_native_log10");
-  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
-      patterns.getContext(), "__spirv_ocl_native_powr");
-  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
-      patterns.getContext(), "__spirv_ocl_native_rsqrt");
-  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_sin");
-  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
-      patterns.getContext(), "__spirv_ocl_native_sqrt");
-  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_tan");
-  if (convertArith)
-    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
-        patterns.getContext(), "__spirv_ocl_native_divide");
-}
-
-namespace {
-struct ConvertMathToXeVMPass
-    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
-  using Base::Base;
-  void runOnOperation() override;
-};
-} // namespace
-
-void ConvertMathToXeVMPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  populateMathToXeVMConversionPatterns(patterns, convertArith);
-  ConversionTarget target(getContext());
-  target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
-    signalPassFailure();
-}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
deleted file mode 100644
index d76627bb4201c..0000000000000
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ /dev/null
@@ -1,155 +0,0 @@
-// RUN: mlir-opt %s -convert-math-to-xevm \
-// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH' 
-// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
-// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
-
-module @test_module {
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-  //
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
-  //
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
-  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-  // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
-
-  // CHECK-LABEL: func @math_ops
-  func.func @math_ops() {
-
-    %c1_f16 = arith.constant 1. : f16
-    %c1_f32 = arith.constant 1. : f32
-    %c1_f64 = arith.constant 1. : f64
-
-    // CHECK: math.exp
-    %exp_normal_f16 = math.exp %c1_f16 : f16
-    // CHECK: math.exp
-    %exp_normal_f32 = math.exp %c1_f32 : f32
-    // CHECK: math.exp
-    %exp_normal_f64 = math.exp %c1_f64 : f64
-
-    // Check float operations are converted properly:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
-    %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
-    %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
-    %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
-    
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
-
-    // CHECK: math.exp
-    %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
-    // CHECK: math.exp
-    %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
-    // CHECK: math.exp
-    %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
-
-    // Check vector operations:
-
-    %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
-    %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
-    %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
-    %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
-    %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
-    %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
-    %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
-    %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
-    %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
-    %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
-
-    %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
-    %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
-    %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
-    %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
-
-    // Check unsupported vector sizes are not converted:
-
-    %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
-    %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
-
-    // CHECK: math.exp
-    %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
-    // CHECK: math.exp
-    %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
-
-    // Check fastmath flags propagate properly:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
-    %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
-    %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
-    %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
-
-    // Check all other math operations:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
-    %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
-
-    %c6_9_f32 = arith.constant 6.9 : f32
-    %c7_f32 = arith.constant 7. : f32
-
-    // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
-    // CHECK-NO-ARITH: arith.divf
-    %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
-
-    return
-  }
-}
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
deleted file mode 100644
index 2492adafd6a50..0000000000000
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ /dev/null
@@ -1,118 +0,0 @@
-// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
-// RUN:             -debug-only=serialize-to-isa 2> %t 
-// RUN: FileCheck --input-file=%t %s
-//
-// MathToXeVM pass generates OpenCL intrinsics function calls when converting
-// Math ops with `fastmath` attr to native function calls. It is assumed that
-// the SPIRV backend would correctly convert these intrinsics calls to OpenCL
-// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
-//
-// To ensure this assumption holds, this test verifies that the SPIRV backend
-// behaves as expected.
-
-module @test_ocl_intrinsics attributes {gpu.container_module} {
-  gpu.module @kernel [#xevm.target] {
-    llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
-      // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
-      // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
-      %c0_f16 = llvm.mlir.constant(0. : f16) : f16
-      // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
-      // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
-      %c0_f32 = llvm.mlir.constant(0. : f32) : f32
-      // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
-      // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
-      %c0_f64 = llvm.mlir.constant(0. : f64) : f64
-
-      // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
-      // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
-      %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
-      // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
-      // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
-      %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
-      // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
-      // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
-      %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
-      // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
-      // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
-      %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
-      // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
-      // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
-      %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>     
-
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
-      %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
-      %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
-      %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
-
-      // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
-      %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
-      // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
-      %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
-      // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
-      %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
-      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
-      %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
-      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
-      %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
-
-      // SPIRV backend does not currently handle fastmath flags: The SPIRV
-      // backend would need to generate OpDecorate calls to decorate math ops
-      // with FPFastMathMode/FPFastMathModeINTEL decorations.
-      //
-      // FIXME: When support for fastmath flags in the SPIRV backend is added, 
-      // add tests here to ensure fastmath flags are converted to the correct
-      // OpDecorate calls.
-      // 
-      // See:
-      // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
-      // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
-
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
-      %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
-      %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
-      %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
-      %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
-      %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
-      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
-      %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
-      %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
-      %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
-      %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
-      %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
-      %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
-
-      llvm.return
-    }
-
-    llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-    llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32>
-    llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
-    llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-    llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
-    llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
-    llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-    llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-    llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
-  }
-}



More information about the Mlir-commits mailing list