[Mlir-commits] [mlir] [mlir][llvm][x86vector] One-to-one intrinsic op interface (PR #140055)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri May 16 00:49:55 PDT 2025
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/140055
>From 63a78c2f962d4baf73c2177aadacf526b86fc0f6 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 30 Apr 2025 16:41:44 +0200
Subject: [PATCH 1/2] [mlir][llvm][x86vector] One-to-one intrinsic op interface
Adds an LLVMIR op interface that can be used by external operations to
model LLVM intrinsics. The related 'op to llvm.call_intrinsic' rewriter
helper is moved into the common LLVM conversion patterns.
The x86vector dialect is refactored to use the new common abstraction.
The one-to-one intrinsic op is tied to LLVM intrinsic call semantics.
Thus, the op interface, previously defined as part of the x86vector
dialect, is moved into the LLVMIR interfaces to allow other low-level
dialects to define operations abstracting specific intrinsic semantics
while minimizing infrastructure duplication.
Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 26 +++
.../mlir/Dialect/LLVMIR/LLVMInterfaces.h | 4 +
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 51 +++++
.../mlir/Dialect/X86Vector/X86Vector.td | 175 +++++++++---------
.../Dialect/X86Vector/X86VectorInterfaces.td | 49 +----
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 39 ++++
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp | 3 +
.../Transforms/LegalizeForLLVMExport.cpp | 82 +-------
8 files changed, 229 insertions(+), 200 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7a58e4fc2f984..ddbac85aa34fd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -30,6 +30,32 @@ LogicalResult oneToOneRewrite(
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
+/// specified name "intrinsic" and operands.
+///
+/// The rewrite performs a simple one-to-one matching between the op and LLVM
+/// intrinsic. For example:
+///
+/// ```mlir
+/// %res = intr.op %val : vector<16xf32>
+/// ```
+///
+/// can be converted to
+///
+/// ```mlir
+/// %res = llvm.call_intrinsic "intrinsic"(%val)
+/// ```
+///
+/// Returns failure, creating no IR, if any operand is not LLVM-compatible.
+///
+/// Upholds a convention that multi-result operations get converted into an
+/// operation returning the LLVM IR structure type, in which case individual
+/// values are first extracted before replacing the original results.
+LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
+ ValueRange operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter);
+
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
index 05bd32c5d45da..d3e5408b73764 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
@@ -16,6 +16,10 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
namespace mlir {
+
+class LLVMTypeConverter;
+class RewriterBase;
+
namespace LLVM {
namespace detail {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 9db019af68b8e..2824f09dab6ce 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -435,6 +435,57 @@ def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
];
}
+def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOpInterface"> {
+ let description = [{
+ An interface for operations modelling LLVM intrinsics suitable for
+ 1-to-1 conversion.
+
+ An op implementing this interface can be directly replaced by a call
+ to a matching intrinsic function.
+ The op must ensure that the combinations of its arguments and results
+ have valid intrinsic counterparts.
+
+ For example, an operation supporting different vector widths:
+ ```mlir
+ %res_v8 = intr.op %value_v8 : vector<8xf32>
+ %res_v16 = intr.op %value_v16 : vector<16xf32>
+ ```
+ can be converted to the following intrinsic calls:
+ ```mlir
+ %res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
+ %res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
+ ```
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the LLVM intrinsic function name matching the operation
+ variant; it may be the base name of an overloaded intrinsic.
+ }],
+ /*retType=*/"std::string",
+ /*methodName=*/"getIntrinsicName"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the operands to pass to the corresponding LLVM intrinsic call.
+
+ Additional operations may be created to map the source operands onto
+ the intrinsic's operands. By default, `operands` is returned as-is.
+ }],
+ /*retType=*/"SmallVector<Value>",
+ /*methodName=*/"getIntrinsicOperands",
+ /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+ "const ::mlir::LLVMTypeConverter &":$typeConverter,
+ "::mlir::RewriterBase &":$rewriter),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return SmallVector<Value>(operands);"
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 25d9c404f0181..66213b0041958 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -40,14 +40,15 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
- // then be removed from assemblyFormat.
- AllTypesMatch<["a", "dst"]>,
- TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
- "dst", "k",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">]> {
+ IntrinsicOpInterface,
+ // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
+ // then be removed from assemblyFormat.
+ AllTypesMatch<["a", "dst"]>,
+ TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
+ "dst", "k",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">
+ ]> {
let summary = "Masked compress op";
let description = [{
The mask.compress op is an AVX512 specific op that can lower to the
@@ -75,14 +76,13 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
" `:` type($dst) (`,` type($src)^)?";
let hasVerifier = 1;
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
// Call the baseline overloaded intrisic.
// Final overload name mangling is resolved by the created function call.
return "llvm.x86.avx512.mask.compress";
}
- }];
- let extraClassDeclaration = [{
+
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -95,12 +95,13 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
//----------------------------------------------------------------------------//
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["src", "a", "dst"]>,
- TypesMatchWith<"imm has the same number of bits as elements in dst",
- "dst", "imm",
- "IntegerType::get($_self.getContext(), "
- "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
+ IntrinsicOpInterface,
+ AllTypesMatch<["src", "a", "dst"]>,
+ TypesMatchWith<"imm has the same number of bits as elements in dst",
+ "dst", "imm",
+ "IntegerType::get($_self.getContext(), "
+ "(::llvm::cast<VectorType>($_self).getShape()[0]))">
+ ]> {
let summary = "Masked roundscale op";
let description = [{
The mask.rndscale op is an AVX512 specific op that can lower to the proper
@@ -126,8 +127,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
let assemblyFormat =
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.mask.rndscale";
VectorType vecType = getSrc().getType();
Type elemType = vecType.getElementType();
@@ -146,12 +147,13 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
//----------------------------------------------------------------------------//
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["src", "a", "b", "dst"]>,
- TypesMatchWith<"k has the same number of bits as elements in dst",
- "dst", "k",
- "IntegerType::get($_self.getContext(), "
- "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
+ IntrinsicOpInterface,
+ AllTypesMatch<["src", "a", "b", "dst"]>,
+ TypesMatchWith<"k has the same number of bits as elements in dst",
+ "dst", "k",
+ "IntegerType::get($_self.getContext(), "
+ "(::llvm::cast<VectorType>($_self).getShape()[0]))">
+ ]> {
let summary = "ScaleF op";
let description = [{
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
@@ -178,8 +180,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
let assemblyFormat =
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.mask.scalef";
VectorType vecType = getSrc().getType();
Type elemType = vecType.getElementType();
@@ -198,18 +200,19 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
//----------------------------------------------------------------------------//
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["a", "b"]>,
- TypesMatchWith<"k1 has the same number of bits as elements in a",
- "a", "k1",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">,
- TypesMatchWith<"k2 has the same number of bits as elements in b",
- // Should use `b` instead of `a`, but that would require
- // adding `type($b)` to assemblyFormat.
- "a", "k2",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">]> {
+ IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ TypesMatchWith<"k1 has the same number of bits as elements in a",
+ "a", "k1",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">,
+ TypesMatchWith<"k2 has the same number of bits as elements in b",
+ // Should use `b` instead of `a`, but that would require
+ // adding `type($b)` to assemblyFormat.
+ "a", "k2",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">
+ ]> {
let summary = "Vp2Intersect op";
let description = [{
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
@@ -234,8 +237,8 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.vp2intersect";
VectorType vecType = getA().getType();
Type elemType = vecType.getElementType();
@@ -254,13 +257,14 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
//----------------------------------------------------------------------------//
def DotBF16Op : AVX512_Op<"dot", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "dst"]>,
- TypesMatchWith<"`a` has twice an many elements as `src`",
- "src", "a",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
- "BFloat16Type::get($_self.getContext()))">]> {
+ IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "dst"]>,
+ TypesMatchWith<"`a` has twice an many elements as `src`",
+ "src", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
+ "BFloat16Type::get($_self.getContext()))">
+ ]> {
let summary = "Dot BF16 op";
let description = [{
The `dot` op is an AVX512-BF16 specific op that can lower to the proper
@@ -286,8 +290,8 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
let assemblyFormat =
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
VectorType vecType = getSrc().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
@@ -303,8 +307,9 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
//----------------------------------------------------------------------------//
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllElementCountsMatch<["a", "dst"]>]> {
+ IntrinsicOpInterface,
+ AllElementCountsMatch<["a", "dst"]>
+ ]> {
let summary = "Convert packed F32 to packed BF16 Data.";
let description = [{
The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
@@ -326,8 +331,8 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
let assemblyFormat =
"$a attr-dict `:` type($a) `->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
VectorType vecType = getA().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
@@ -357,15 +362,16 @@ class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def RsqrtOp : AVX_Op<"rsqrt", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- SameOperandsAndResultType]> {
+ IntrinsicOpInterface,
+ SameOperandsAndResultType
+ ]> {
let summary = "Rsqrt";
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
let assemblyFormat = "$a attr-dict `:` type($a)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
return "llvm.x86.avx.rsqrt.ps.256";
}
}];
@@ -376,8 +382,9 @@ def RsqrtOp : AVX_Op<"rsqrt", [Pure,
//----------------------------------------------------------------------------//
def DotOp : AVX_LowOp<"dot", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- SameOperandsAndResultType]> {
+ IntrinsicOpInterface,
+ SameOperandsAndResultType
+ ]> {
let summary = "Dot";
let description = [{
Computes the 4-way dot products of the lower and higher parts of the source
@@ -400,13 +407,12 @@ def DotOp : AVX_LowOp<"dot", [Pure,
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
// Only one variant is supported right now - no extra mangling.
return "llvm.x86.avx.dp.ps.256";
}
- }];
- let extraClassDeclaration = [{
+
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -418,8 +424,11 @@ def DotOp : AVX_LowOp<"dot", [Pure,
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
-def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def BcstToPackedF32Op
+ : AVX_Op<"bcst_to_f32.packed", [
+ MemoryEffects<[MemRead]>,
+ IntrinsicOpInterface
+ ]> {
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -440,8 +449,8 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -455,9 +464,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -470,8 +477,11 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
//------------------------------------------------------------------------------//
-def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def CvtPackedEvenIndexedToF32Op
+ : AVX_Op<"cvt.packed.even.indexed_to_f32", [
+ MemoryEffects<[MemRead]>,
+ IntrinsicOpInterface
+ ]> {
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -491,8 +501,8 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -506,9 +516,7 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -516,8 +524,11 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
}];
}
-def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def CvtPackedOddIndexedToF32Op
+ : AVX_Op<"cvt.packed.odd.indexed_to_f32", [
+ MemoryEffects<[MemRead]>,
+ IntrinsicOpInterface
+ ]> {
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -537,8 +548,8 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -552,9 +563,7 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index cde9d1dce65ee..e057618dc0e98 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -14,57 +14,18 @@
#define X86VECTOR_INTERFACES
include "mlir/IR/Interfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
//===----------------------------------------------------------------------===//
-// One-to-One Intrinsic Interface
+// Intrinsic Interface
//===----------------------------------------------------------------------===//
-def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
+def IntrinsicOpInterface
+ : OpInterface<"IntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
- Interface for 1-to-1 conversion of an operation into LLVM intrinsics.
-
- An op implementing this interface can be simply replaced by a call
- to a matching intrinsic function.
- The op must ensure that the combinations of their arguments and results
- have valid intrinsic counterparts.
-
- For example, an operation supporting different vector widths:
- ```mlir
- %res_v8 = x86vector.op %value_v8 : vector<8xf32>
- %res_v16 = x86vector.op %value_v16 : vector<16xf32>
- ```
- can be converted to the following intrinsic calls:
- ```mlir
- %res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
- %res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
- ```
+ A wrapper interface for operations representing x86 LLVM intrinsics.
}];
let cppNamespace = "::mlir::x86vector";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns mangled LLVM intrinsic function name matching the operation
- variant.
- }],
- /*retType=*/"std::string",
- /*methodName=*/"getIntrinsicName"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns operands for a corresponding LLVM intrinsic.
-
- Additional operations may be created to facilitate mapping
- between the source operands and the target intrinsic.
- }],
- /*retType=*/"SmallVector<Value>",
- /*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
- "const ::mlir::LLVMTypeConverter &":$typeConverter,
- "::mlir::RewriterBase &":$rewriter),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return SmallVector<Value>(operands);"
- >,
- ];
}
#endif // X86VECTOR_INTERFACES
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 0505214de2015..dbf5a9e59ee9f 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -382,6 +382,45 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return success();
}
+LogicalResult LLVM::detail::intrinsicRewrite(
+    Operation *op, StringRef intrinsic, ValueRange operands,
+    const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
+  Location loc = op->getLoc();
+
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      }))
+    return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
+
+  unsigned numResults = op->getNumResults();
+  Type resType;
+  if (numResults != 0)
+    resType = typeConverter.packOperationResults(op->getResultTypes());
+
+  auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
+      loc, resType, rewriter.getStringAttr(intrinsic), operands);
+  // Propagate attributes.
+  callIntrOp->setAttrs(op->getAttrDictionary());
+
+  if (numResults <= 1) {
+    // Directly replace the original op.
+    rewriter.replaceOp(op, callIntrOp);
+    return success();
+  }
+
+  // Extract individual results from packed structure and use them as
+  // replacements.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  Value intrRes = callIntrOp.getResults();
+  for (unsigned i = 0; i < numResults; ++i) {
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
+  }
+  rewriter.replaceOp(op, results);
+
+  return success();
+}
+
static unsigned getBitWidth(Type type) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index fe8aa1a918a98..769ef0030644f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -11,7 +11,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::LLVM;
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 483c1f5c3e4c6..331d9aee92665 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -20,84 +20,23 @@ using namespace mlir::x86vector;
namespace {
-/// Replaces an operation with a call to an LLVM intrinsic with the specified
-/// name and operands.
-///
-/// The rewrite performs a simple one-to-one matching between the op and LLVM
-/// intrinsic. For example:
-///
-/// ```mlir
-/// %res = x86vector.op %val : vector<16xf32>
-/// ```
-///
-/// can be converted to
-///
-/// ```mlir
-/// %res = llvm.call_intrinsic "intrinsic"(%val)
-/// ```
-///
-/// The provided operands must be LLVM-compatible.
-///
-/// Upholds a convention that multi-result operations get converted into an
-/// operation returning the LLVM IR structure type, in which case individual
-/// values are first extracted before replacing the original results.
-LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
- ValueRange operands,
- const LLVMTypeConverter &typeConverter,
- PatternRewriter &rewriter) {
- auto loc = op->getLoc();
-
- if (!llvm::all_of(operands, [](Value value) {
- return LLVM::isCompatibleType(value.getType());
- }))
- return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
-
- unsigned numResults = op->getNumResults();
- Type resType;
- if (numResults != 0)
- resType = typeConverter.packOperationResults(op->getResultTypes());
-
- auto callIntrOp =
- rewriter.create<LLVM::CallIntrinsicOp>(loc, resType, intrinsic, operands);
- // Propagate attributes.
- callIntrOp->setAttrs(op->getAttrDictionary());
-
- if (numResults <= 1) {
- // Directly replace the original op.
- rewriter.replaceOp(op, callIntrOp);
- return success();
- }
-
- // Extract individual results from packed structure and use them as
- // replacements.
- SmallVector<Value, 4> results;
- results.reserve(numResults);
- Value intrRes = callIntrOp.getResults();
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
- }
- rewriter.replaceOp(op, results);
-
- return success();
-}
-
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
-struct OneToOneIntrinsicOpConversion
- : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+struct IntrinsicOpConversion
+ : public OpInterfaceConversionPattern<x86vector::IntrinsicOp> {
using OpInterfaceConversionPattern<
- x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
+ x86vector::IntrinsicOp>::OpInterfaceConversionPattern;
- OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
+ IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
benefit),
typeConverter(typeConverter) {}
LogicalResult
- matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+ matchAndRewrite(x86vector::IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- return intrinsicRewrite(
+ return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
@@ -112,13 +51,10 @@ struct OneToOneIntrinsicOpConversion
/// Populate the given list with patterns that convert from X86Vector to LLVM.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<OneToOneIntrinsicOpConversion>(converter);
+ patterns.add<IntrinsicOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
- Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
- CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
- BcstToPackedF32Op, RsqrtOp, DotOp>();
+ target.addIllegalDialect<X86VectorDialect>();
}
>From 1d11fe015ddc8a4be481dd1a96c0fe04e036a5ab Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 16 May 2025 09:49:39 +0200
Subject: [PATCH 2/2] Address comments
---
.../mlir/Dialect/X86Vector/X86Vector.td | 22 +++++++++----------
.../Dialect/X86Vector/X86VectorInterfaces.td | 4 ++--
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 3 +--
.../Transforms/LegalizeForLLVMExport.cpp | 14 ++++++------
4 files changed, 21 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 66213b0041958..3bf0be0a716aa 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -40,7 +40,7 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
// then be removed from assemblyFormat.
AllTypesMatch<["a", "dst"]>,
@@ -95,7 +95,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
//----------------------------------------------------------------------------//
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
AllTypesMatch<["src", "a", "dst"]>,
TypesMatchWith<"imm has the same number of bits as elements in dst",
"dst", "imm",
@@ -147,7 +147,7 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
//----------------------------------------------------------------------------//
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
AllTypesMatch<["src", "a", "b", "dst"]>,
TypesMatchWith<"k has the same number of bits as elements in dst",
"dst", "k",
@@ -200,7 +200,7 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
//----------------------------------------------------------------------------//
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
AllTypesMatch<["a", "b"]>,
TypesMatchWith<"k1 has the same number of bits as elements in a",
"a", "k1",
@@ -257,7 +257,7 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
//----------------------------------------------------------------------------//
def DotBF16Op : AVX512_Op<"dot", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
AllTypesMatch<["a", "b"]>,
AllTypesMatch<["src", "dst"]>,
TypesMatchWith<"`a` has twice an many elements as `src`",
@@ -307,7 +307,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
//----------------------------------------------------------------------------//
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
AllElementCountsMatch<["a", "dst"]>
]> {
let summary = "Convert packed F32 to packed BF16 Data.";
@@ -362,7 +362,7 @@ class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def RsqrtOp : AVX_Op<"rsqrt", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
SameOperandsAndResultType
]> {
let summary = "Rsqrt";
@@ -382,7 +382,7 @@ def RsqrtOp : AVX_Op<"rsqrt", [Pure,
//----------------------------------------------------------------------------//
def DotOp : AVX_LowOp<"dot", [Pure,
- IntrinsicOpInterface,
+ X86IntrinsicOpInterface,
SameOperandsAndResultType
]> {
let summary = "Dot";
@@ -427,7 +427,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
def BcstToPackedF32Op
: AVX_Op<"bcst_to_f32.packed", [
MemoryEffects<[MemRead]>,
- IntrinsicOpInterface
+ X86IntrinsicOpInterface
]> {
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
@@ -480,7 +480,7 @@ def BcstToPackedF32Op
def CvtPackedEvenIndexedToF32Op
: AVX_Op<"cvt.packed.even.indexed_to_f32", [
MemoryEffects<[MemRead]>,
- IntrinsicOpInterface
+ X86IntrinsicOpInterface
]> {
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
@@ -527,7 +527,7 @@ def CvtPackedEvenIndexedToF32Op
def CvtPackedOddIndexedToF32Op
: AVX_Op<"cvt.packed.odd.indexed_to_f32", [
MemoryEffects<[MemRead]>,
- IntrinsicOpInterface
+ X86IntrinsicOpInterface
]> {
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index e057618dc0e98..f4579dcfa9566 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -20,8 +20,8 @@ include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
// Intrinsic Interface
//===----------------------------------------------------------------------===//
-def IntrinsicOpInterface
- : OpInterface<"IntrinsicOp", [OneToOneIntrinsicOpInterface]> {
+def X86IntrinsicOpInterface
+ : OpInterface<"X86IntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
A wrapper interface for operations representing x86 LLVM intrinsics.
}];
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index dbf5a9e59ee9f..b975d6f7a6a3c 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -413,9 +413,8 @@ LogicalResult LLVM::detail::intrinsicRewrite(
SmallVector<Value, 4> results;
results.reserve(numResults);
Value intrRes = callIntrOp.getResults();
- for (unsigned i = 0; i < numResults; ++i) {
+ for (unsigned i = 0; i < numResults; ++i)
results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
- }
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 331d9aee92665..b2fc2f3f40e8c 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -22,19 +22,19 @@ namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
-struct IntrinsicOpConversion
- : public OpInterfaceConversionPattern<x86vector::IntrinsicOp> {
+struct X86IntrinsicOpConversion
+ : public OpInterfaceConversionPattern<x86vector::X86IntrinsicOp> {
using OpInterfaceConversionPattern<
- x86vector::IntrinsicOp>::OpInterfaceConversionPattern;
+ x86vector::X86IntrinsicOp>::OpInterfaceConversionPattern;
- IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
+ X86IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
benefit),
typeConverter(typeConverter) {}
LogicalResult
- matchAndRewrite(x86vector::IntrinsicOp op, ArrayRef<Value> operands,
+ matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
@@ -51,7 +51,7 @@ struct IntrinsicOpConversion
/// Populate the given list with patterns that convert from X86Vector to LLVM.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<IntrinsicOpConversion>(converter);
+ patterns.add<X86IntrinsicOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
More information about the Mlir-commits
mailing list