[Mlir-commits] [mlir] c798e19 - [mlir][llvm][x86vector] One-to-one intrinsic op interface (#140055)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 18 22:42:29 PDT 2025
Author: Adam Siemieniuk
Date: 2025-05-19T07:42:25+02:00
New Revision: c798e195409a2d6fabbb166285d9dfe8ca7599a7
URL: https://github.com/llvm/llvm-project/commit/c798e195409a2d6fabbb166285d9dfe8ca7599a7
DIFF: https://github.com/llvm/llvm-project/commit/c798e195409a2d6fabbb166285d9dfe8ca7599a7.diff
LOG: [mlir][llvm][x86vector] One-to-one intrinsic op interface (#140055)
Adds an LLVMIR op interface that can be used by external operations to
model LLVM intrinsics. Related 'op to llvm.call_intrinsic' rewriter
helper is moved into common LLVM conversion patterns. The x86vector
dialect is refactored to use the new common abstraction.
The one-to-one intrinsic op is tied to LLVM intrinsic call semantics.
Thus, the op interface, previously defined as a part of x86vector
dialect, is moved into the LLVMIR interfaces to allow other low-level
dialects to define operations abstracting specific intrinsic semantics
while minimizing infrastructure duplication.
Related RFC:
https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/6
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7a58e4fc2f984..ddbac85aa34fd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -30,6 +30,32 @@ LogicalResult oneToOneRewrite(
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
+/// specified name "intrinsic" and operands.
+///
+/// The rewrite performs a simple one-to-one matching between the op and LLVM
+/// intrinsic. For example:
+///
+/// ```mlir
+/// %res = intr.op %val : vector<16xf32>
+/// ```
+///
+/// can be converted to
+///
+/// ```mlir
+/// %res = llvm.call_intrinsic "intrinsic"(%val)
+/// ```
+///
+/// The provided operands must be LLVM-compatible. If any operand type is not,
+/// failure is returned before any IR is created. On success, the attributes
+/// of "op" are copied onto the created `llvm.call_intrinsic`.
+///
+/// Upholds a convention that multi-result operations get converted into an
+/// operation returning the LLVM IR structure type, in which case individual
+/// values are first extracted before replacing the original results.
+LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
+ ValueRange operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter);
+
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
index 05bd32c5d45da..d3e5408b73764 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
@@ -16,6 +16,10 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
namespace mlir {
+
+// Forward declarations for types that appear in LLVM op interface method
+// signatures (e.g. OneToOneIntrinsicOpInterface::getIntrinsicOperands).
+class LLVMTypeConverter;
+class RewriterBase;
+
namespace LLVM {
namespace detail {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 9db019af68b8e..2824f09dab6ce 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -435,6 +435,57 @@ def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
];
}
+def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOpInterface"> {
+ let description = [{
+ An interface for operations modelling LLVM intrinsics suitable for
+ 1-to-1 conversion.
+
+ An op implementing this interface can be directly replaced by a call
+ to a matching intrinsic function.
+ The op must ensure that the combinations of its arguments and results
+ have valid intrinsic counterparts.
+
+ For example, an operation supporting different inputs:
+ ```mlir
+ %res_v8 = intr.op %value_v8 : vector<8xf32>
+ %res_v16 = intr.op %value_v16 : vector<16xf32>
+ ```
+ can be converted to the following intrinsic calls:
+ ```mlir
+ %res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
+ %res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
+ ```
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns mangled LLVM intrinsic function name matching the operation
+ variant.
+ }],
+ /*retType=*/"std::string",
+ /*methodName=*/"getIntrinsicName"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns operands for a corresponding LLVM intrinsic.
+
+ Additional operations may be created to facilitate mapping
+ between the source operands and the target intrinsic.
+ The default implementation returns the given operands unchanged.
+ }],
+ /*retType=*/"SmallVector<Value>",
+ /*methodName=*/"getIntrinsicOperands",
+ /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+ "const ::mlir::LLVMTypeConverter &":$typeConverter,
+ "::mlir::RewriterBase &":$rewriter),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return SmallVector<Value>(operands);"
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 25d9c404f0181..3bf0be0a716aa 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -40,14 +40,15 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
- // then be removed from assemblyFormat.
- AllTypesMatch<["a", "dst"]>,
- TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
- "dst", "k",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">]> {
+ X86IntrinsicOpInterface,
+ // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
+ // then be removed from assemblyFormat.
+ AllTypesMatch<["a", "dst"]>,
+ TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
+ "dst", "k",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">
+ ]> {
let summary = "Masked compress op";
let description = [{
The mask.compress op is an AVX512 specific op that can lower to the
@@ -75,14 +76,13 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
" `:` type($dst) (`,` type($src)^)?";
let hasVerifier = 1;
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
// Call the baseline overloaded intrisic.
// Final overload name mangling is resolved by the created function call.
return "llvm.x86.avx512.mask.compress";
}
- }];
- let extraClassDeclaration = [{
+
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -95,12 +95,13 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
//----------------------------------------------------------------------------//
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["src", "a", "dst"]>,
- TypesMatchWith<"imm has the same number of bits as elements in dst",
- "dst", "imm",
- "IntegerType::get($_self.getContext(), "
- "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["src", "a", "dst"]>,
+ TypesMatchWith<"imm has the same number of bits as elements in dst",
+ "dst", "imm",
+ "IntegerType::get($_self.getContext(), "
+ "(::llvm::cast<VectorType>($_self).getShape()[0]))">
+ ]> {
let summary = "Masked roundscale op";
let description = [{
The mask.rndscale op is an AVX512 specific op that can lower to the proper
@@ -126,8 +127,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
let assemblyFormat =
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.mask.rndscale";
VectorType vecType = getSrc().getType();
Type elemType = vecType.getElementType();
@@ -146,12 +147,13 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
//----------------------------------------------------------------------------//
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["src", "a", "b", "dst"]>,
- TypesMatchWith<"k has the same number of bits as elements in dst",
- "dst", "k",
- "IntegerType::get($_self.getContext(), "
- "(::llvm::cast<VectorType>($_self).getShape()[0]))">]> {
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["src", "a", "b", "dst"]>,
+ TypesMatchWith<"k has the same number of bits as elements in dst",
+ "dst", "k",
+ "IntegerType::get($_self.getContext(), "
+ "(::llvm::cast<VectorType>($_self).getShape()[0]))">
+ ]> {
let summary = "ScaleF op";
let description = [{
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
@@ -178,8 +180,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
let assemblyFormat =
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.mask.scalef";
VectorType vecType = getSrc().getType();
Type elemType = vecType.getElementType();
@@ -198,18 +200,19 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
//----------------------------------------------------------------------------//
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["a", "b"]>,
- TypesMatchWith<"k1 has the same number of bits as elements in a",
- "a", "k1",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">,
- TypesMatchWith<"k2 has the same number of bits as elements in b",
- // Should use `b` instead of `a`, but that would require
- // adding `type($b)` to assemblyFormat.
- "a", "k2",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
- "IntegerType::get($_self.getContext(), 1))">]> {
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ TypesMatchWith<"k1 has the same number of bits as elements in a",
+ "a", "k1",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">,
+ TypesMatchWith<"k2 has the same number of bits as elements in b",
+ // Should use `b` instead of `a`, but that would require
+ // adding `type($b)` to assemblyFormat.
+ "a", "k2",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">
+ ]> {
let summary = "Vp2Intersect op";
let description = [{
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
@@ -234,8 +237,8 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512.vp2intersect";
VectorType vecType = getA().getType();
Type elemType = vecType.getElementType();
@@ -254,13 +257,14 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
//----------------------------------------------------------------------------//
def DotBF16Op : AVX512_Op<"dot", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "dst"]>,
- TypesMatchWith<"`a` has twice an many elements as `src`",
- "src", "a",
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
- "BFloat16Type::get($_self.getContext()))">]> {
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "dst"]>,
+ TypesMatchWith<"`a` has twice an many elements as `src`",
+ "src", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
+ "BFloat16Type::get($_self.getContext()))">
+ ]> {
let summary = "Dot BF16 op";
let description = [{
The `dot` op is an AVX512-BF16 specific op that can lower to the proper
@@ -286,8 +290,8 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
let assemblyFormat =
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
VectorType vecType = getSrc().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
@@ -303,8 +307,9 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
//----------------------------------------------------------------------------//
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- AllElementCountsMatch<["a", "dst"]>]> {
+ X86IntrinsicOpInterface,
+ AllElementCountsMatch<["a", "dst"]>
+ ]> {
let summary = "Convert packed F32 to packed BF16 Data.";
let description = [{
The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
@@ -326,8 +331,8 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
let assemblyFormat =
"$a attr-dict `:` type($a) `->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
VectorType vecType = getA().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
@@ -357,15 +362,16 @@ class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
//----------------------------------------------------------------------------//
def RsqrtOp : AVX_Op<"rsqrt", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- SameOperandsAndResultType]> {
+ X86IntrinsicOpInterface,
+ SameOperandsAndResultType
+ ]> {
let summary = "Rsqrt";
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
let assemblyFormat = "$a attr-dict `:` type($a)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
return "llvm.x86.avx.rsqrt.ps.256";
}
}];
@@ -376,8 +382,9 @@ def RsqrtOp : AVX_Op<"rsqrt", [Pure,
//----------------------------------------------------------------------------//
def DotOp : AVX_LowOp<"dot", [Pure,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
- SameOperandsAndResultType]> {
+ X86IntrinsicOpInterface,
+ SameOperandsAndResultType
+ ]> {
let summary = "Dot";
let description = [{
Computes the 4-way dot products of the lower and higher parts of the source
@@ -400,13 +407,12 @@ def DotOp : AVX_LowOp<"dot", [Pure,
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
// Only one variant is supported right now - no extra mangling.
return "llvm.x86.avx.dp.ps.256";
}
- }];
- let extraClassDeclaration = [{
+
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -418,8 +424,11 @@ def DotOp : AVX_LowOp<"dot", [Pure,
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
-def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def BcstToPackedF32Op
+ : AVX_Op<"bcst_to_f32.packed", [
+ MemoryEffects<[MemRead]>,
+ X86IntrinsicOpInterface
+ ]> {
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -440,8 +449,8 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -455,9 +464,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -470,8 +477,11 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
//------------------------------------------------------------------------------//
-def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def CvtPackedEvenIndexedToF32Op
+ : AVX_Op<"cvt.packed.even.indexed_to_f32", [
+ MemoryEffects<[MemRead]>,
+ X86IntrinsicOpInterface
+ ]> {
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -491,8 +501,8 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -506,9 +516,7 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
@@ -516,8 +524,11 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
}];
}
-def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+def CvtPackedOddIndexedToF32Op
+ : AVX_Op<"cvt.packed.odd.indexed_to_f32", [
+ MemoryEffects<[MemRead]>,
+ X86IntrinsicOpInterface
+ ]> {
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
@@ -537,8 +548,8 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
@@ -552,9 +563,7 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
intr += std::to_string(opBitWidth);
return intr;
}
- }];
- let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index cde9d1dce65ee..6fef87e27361e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -14,57 +14,18 @@
#define X86VECTOR_INTERFACES
include "mlir/IR/Interfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
//===----------------------------------------------------------------------===//
-// One-to-One Intrinsic Interface
+// X86 Intrinsic Interface
//===----------------------------------------------------------------------===//
-def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
+def X86IntrinsicOpInterface
+ : OpInterface<"X86IntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
- Interface for 1-to-1 conversion of an operation into LLVM intrinsics.
-
- An op implementing this interface can be simply replaced by a call
- to a matching intrinsic function.
- The op must ensure that the combinations of their arguments and results
- have valid intrinsic counterparts.
-
- For example, an operation supporting different vector widths:
- ```mlir
- %res_v8 = x86vector.op %value_v8 : vector<8xf32>
- %res_v16 = x86vector.op %value_v16 : vector<16xf32>
- ```
- can be converted to the following intrinsic calls:
- ```mlir
- %res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
- %res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
- ```
+ A wrapper interface for operations representing x86 LLVM intrinsics.
+ It inherits the methods of `OneToOneIntrinsicOpInterface` and gives
+ x86vector conversion patterns a dialect-specific interface to match on.
}];
let cppNamespace = "::mlir::x86vector";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns mangled LLVM intrinsic function name matching the operation
- variant.
- }],
- /*retType=*/"std::string",
- /*methodName=*/"getIntrinsicName"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns operands for a corresponding LLVM intrinsic.
-
- Additional operations may be created to facilitate mapping
- between the source operands and the target intrinsic.
- }],
- /*retType=*/"SmallVector<Value>",
- /*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
- "const ::mlir::LLVMTypeConverter &":$typeConverter,
- "::mlir::RewriterBase &":$rewriter),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return SmallVector<Value>(operands);"
- >,
- ];
}
#endif // X86VECTOR_INTERFACES
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 0505214de2015..b975d6f7a6a3c 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -382,6 +382,44 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return success();
}
+LogicalResult LLVM::detail::intrinsicRewrite(
+    Operation *op, StringRef intrinsic, ValueRange operands,
+    const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
+  Location loc = op->getLoc();
+
+  // Bail out before creating any IR: the intrinsic call can only consume
+  // operands that already have LLVM-compatible types. Report the reason via
+  // the rewriter's match-failure hook instead of failing silently.
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      }))
+    return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
+
+  // Results are packed into a single LLVM result type; multiple results are
+  // unpacked again with extractvalue below. resType stays null for ops
+  // without results.
+  unsigned numResults = op->getNumResults();
+  Type resType;
+  if (numResults != 0)
+    resType = typeConverter.packOperationResults(op->getResultTypes());
+
+  auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
+      loc, resType, rewriter.getStringAttr(intrinsic), operands);
+  // Propagate attributes.
+  callIntrOp->setAttrs(op->getAttrDictionary());
+
+  if (numResults <= 1) {
+    // Directly replace the original op.
+    rewriter.replaceOp(op, callIntrOp);
+    return success();
+  }
+
+  // Extract individual results from packed structure and use them as
+  // replacements.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  Value intrRes = callIntrOp.getResults();
+  for (unsigned i = 0; i < numResults; ++i)
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
+  rewriter.replaceOp(op, results);
+
+  return success();
+}
+
static unsigned getBitWidth(Type type) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index fe8aa1a918a98..769ef0030644f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -11,7 +11,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::LLVM;
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 483c1f5c3e4c6..b2fc2f3f40e8c 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -20,84 +20,23 @@ using namespace mlir::x86vector;
namespace {
-/// Replaces an operation with a call to an LLVM intrinsic with the specified
-/// name and operands.
-///
-/// The rewrite performs a simple one-to-one matching between the op and LLVM
-/// intrinsic. For example:
-///
-/// ```mlir
-/// %res = x86vector.op %val : vector<16xf32>
-/// ```
-///
-/// can be converted to
-///
-/// ```mlir
-/// %res = llvm.call_intrinsic "intrinsic"(%val)
-/// ```
-///
-/// The provided operands must be LLVM-compatible.
-///
-/// Upholds a convention that multi-result operations get converted into an
-/// operation returning the LLVM IR structure type, in which case individual
-/// values are first extracted before replacing the original results.
-LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
- ValueRange operands,
- const LLVMTypeConverter &typeConverter,
- PatternRewriter &rewriter) {
- auto loc = op->getLoc();
-
- if (!llvm::all_of(operands, [](Value value) {
- return LLVM::isCompatibleType(value.getType());
- }))
- return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
-
- unsigned numResults = op->getNumResults();
- Type resType;
- if (numResults != 0)
- resType = typeConverter.packOperationResults(op->getResultTypes());
-
- auto callIntrOp =
- rewriter.create<LLVM::CallIntrinsicOp>(loc, resType, intrinsic, operands);
- // Propagate attributes.
- callIntrOp->setAttrs(op->getAttrDictionary());
-
- if (numResults <= 1) {
- // Directly replace the original op.
- rewriter.replaceOp(op, callIntrOp);
- return success();
- }
-
- // Extract individual results from packed structure and use them as
- // replacements.
- SmallVector<Value, 4> results;
- results.reserve(numResults);
- Value intrRes = callIntrOp.getResults();
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
- }
- rewriter.replaceOp(op, results);
-
- return success();
-}
-
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
-struct OneToOneIntrinsicOpConversion
- : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+struct X86IntrinsicOpConversion
+ : public OpInterfaceConversionPattern<x86vector::X86IntrinsicOp> {
using OpInterfaceConversionPattern<
- x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
+ x86vector::X86IntrinsicOp>::OpInterfaceConversionPattern;
- OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
+ X86IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
benefit),
typeConverter(typeConverter) {}
LogicalResult
- matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+ matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- return intrinsicRewrite(
+ return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
@@ -112,13 +51,10 @@ struct OneToOneIntrinsicOpConversion
/// Populate the given list with patterns that convert from X86Vector to LLVM.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<OneToOneIntrinsicOpConversion>(converter);
+ // A single interface-based pattern handles every op implementing
+ // X86IntrinsicOp by rewriting it into an llvm.call_intrinsic.
+ patterns.add<X86IntrinsicOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
- Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
- CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
- BcstToPackedF32Op, RsqrtOp, DotOp>();
+ // Mark the whole dialect illegal instead of listing ops. The only pattern
+ // populated for this dialect is X86IntrinsicOpConversion, so any x86vector
+ // op not implementing X86IntrinsicOp cannot be legalized.
+ // NOTE(review): confirm that every x86vector op implements X86IntrinsicOp.
+ target.addIllegalDialect<X86VectorDialect>();
}
More information about the Mlir-commits
mailing list