[Mlir-commits] [mlir] [mlir][x86vector] Improve intrinsic operands creation (PR #138666)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 6 02:52:36 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Refactors the intrinsic op interface to delegate the initial mapping of operands to the dialect conversion framework, so that the intrinsic operand getters only need to perform last-mile post-processing.
---
Full diff: https://github.com/llvm/llvm-project/pull/138666.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+20-5)
- (modified) mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td (+4-2)
- (modified) mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp (+36-36)
- (modified) mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (+12-9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4f8301f9380b8..25d9c404f0181 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
#endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 5176f4a447b6e..cde9d1dce65ee 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
+ /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+ "const ::mlir::LLVMTypeConverter &":$typeConverter,
+ "::mlir::RewriterBase &":$rewriter),
/*methodBody=*/"",
- /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+ /*defaultImplementation=*/"return SmallVector<Value>(operands);"
>,
];
}
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 8d383b1f8103b..cc7ab7f3f3895 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
-static SmallVector<Value>
-getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
- RewriterBase &rewriter,
- const LLVMTypeConverter &typeConverter) {
- SmallVector<Value> operands;
- auto opType = memrefVal.getType();
-
- Type llvmStructType = typeConverter.convertType(opType);
- Value llvmStruct =
- rewriter
- .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
- .getResult(0);
- MemRefDescriptor memRefDescriptor(llvmStruct);
-
- Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
- operands.push_back(ptr);
-
- return operands;
+static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ MemRefDescriptor memRefDescriptor(buffer);
+ return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
}
LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
}
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
auto loc = getLoc();
+ Adaptor adaptor(operands, *this);
- auto opType = getA().getType();
+ auto opType = adaptor.getA().getType();
Value src;
- if (getSrc()) {
- src = getSrc();
- } else if (getConstantSrc()) {
- src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+ if (adaptor.getSrc()) {
+ src = adaptor.getSrc();
+ } else if (adaptor.getConstantSrc()) {
+ src = rewriter.create<LLVM::ConstantOp>(loc, opType,
+ adaptor.getConstantSrcAttr());
} else {
auto zeroAttr = rewriter.getZeroAttr(opType);
src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
}
- return SmallVector<Value>{getA(), src, getK()};
+ return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
}
SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
- const LLVMTypeConverter &typeConverter) {
- SmallVector<Value> operands(getOperands());
+x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements.
Value scale =
rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
- operands.push_back(scale);
+ intrinsicOperands.push_back(scale);
- return operands;
+ return intrinsicOperands;
}
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9ee44a63ba2e4..483c1f5c3e4c6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct OneToOneIntrinsicOpConversion
- : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
- using OpInterfaceRewritePattern<
- x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
+ : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+ using OpInterfaceConversionPattern<
+ x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
+ : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+ benefit),
typeConverter(typeConverter) {}
- LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
- PatternRewriter &rewriter) const override {
- return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
- op.getIntrinsicOperands(rewriter, typeConverter),
- typeConverter, rewriter);
+ LogicalResult
+ matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ return intrinsicRewrite(
+ op, rewriter.getStringAttr(op.getIntrinsicName()),
+ op.getIntrinsicOperands(operands, typeConverter, rewriter),
+ typeConverter, rewriter);
}
private:
``````````
</details>
https://github.com/llvm/llvm-project/pull/138666
More information about the Mlir-commits
mailing list