[Mlir-commits] [mlir] [mlir][x86vector] Improve intrinsic operands creation (PR #138666)

Adam Siemieniuk llvmlistbot at llvm.org
Tue May 6 02:52:00 PDT 2025


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/138666

Refactors the intrinsic op interface to delegate the initial operand mapping to the dialect conversion framework, allowing intrinsic operand getters to perform only the last-mile post-processing.

>From 8268feab1250137af9e41288db9dd3b7f48ca1ed Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 6 May 2025 10:23:38 +0200
Subject: [PATCH] [mlir][x86vector] Improve intrinsic operands creation

Refactors the intrinsic op interface to delegate the initial operand
mapping to the dialect conversion framework, allowing intrinsic operand
getters to perform only the last-mile post-processing.
---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 25 +++++--
 .../Dialect/X86Vector/X86VectorInterfaces.td  |  6 +-
 .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 72 +++++++++----------
 .../Transforms/LegalizeForLLVMExport.cpp      | 21 +++---
 4 files changed, 72 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4f8301f9380b8..25d9c404f0181 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 
 }
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 #endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 5176f4a447b6e..cde9d1dce65ee 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
       }],
       /*retType=*/"SmallVector<Value>",
       /*methodName=*/"getIntrinsicOperands",
-      /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
+      /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+                    "const ::mlir::LLVMTypeConverter &":$typeConverter,
+                    "::mlir::RewriterBase &":$rewriter),
       /*methodBody=*/"",
-      /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+      /*defaultImplementation=*/"return SmallVector<Value>(operands);"
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 8d383b1f8103b..cc7ab7f3f3895 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
       >();
 }
 
-static SmallVector<Value>
-getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
-                 RewriterBase &rewriter,
-                 const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands;
-  auto opType = memrefVal.getType();
-
-  Type llvmStructType = typeConverter.convertType(opType);
-  Value llvmStruct =
-      rewriter
-          .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
-          .getResult(0);
-  MemRefDescriptor memRefDescriptor(llvmStruct);
-
-  Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
-  operands.push_back(ptr);
-
-  return operands;
+static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
+                              const LLVMTypeConverter &typeConverter,
+                              RewriterBase &rewriter) {
+  MemRefDescriptor memRefDescriptor(buffer);
+  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
 }
 
 LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
 }
 
 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
   auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
 
-  auto opType = getA().getType();
+  auto opType = adaptor.getA().getType();
   Value src;
-  if (getSrc()) {
-    src = getSrc();
-  } else if (getConstantSrc()) {
-    src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+  if (adaptor.getSrc()) {
+    src = adaptor.getSrc();
+  } else if (adaptor.getConstantSrc()) {
+    src = rewriter.create<LLVM::ConstantOp>(loc, opType,
+                                            adaptor.getConstantSrcAttr());
   } else {
     auto zeroAttr = rewriter.getZeroAttr(opType);
     src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
   }
 
-  return SmallVector<Value>{getA(), src, getK()};
+  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
 }
 
 SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
-                                       const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands(getOperands());
+x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                       const LLVMTypeConverter &typeConverter,
+                                       RewriterBase &rewriter) {
+  SmallVector<Value> intrinsicOperands(operands);
   // Dot product of all elements, broadcasted to all elements.
   Value scale =
       rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
-  operands.push_back(scale);
+  intrinsicOperands.push_back(scale);
 
-  return operands;
+  return intrinsicOperands;
 }
 
 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9ee44a63ba2e4..483c1f5c3e4c6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
 /// Generic one-to-one conversion of simply mappable operations into calls
 /// to their respective LLVM intrinsics.
 struct OneToOneIntrinsicOpConversion
-    : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
-  using OpInterfaceRewritePattern<
-      x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
+    : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+  using OpInterfaceConversionPattern<
+      x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
 
   OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
                                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
+      : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+                                     benefit),
         typeConverter(typeConverter) {}
 
-  LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
-                                PatternRewriter &rewriter) const override {
-    return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
-                            op.getIntrinsicOperands(rewriter, typeConverter),
-                            typeConverter, rewriter);
+  LogicalResult
+  matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return intrinsicRewrite(
+        op, rewriter.getStringAttr(op.getIntrinsicName()),
+        op.getIntrinsicOperands(operands, typeConverter, rewriter),
+        typeConverter, rewriter);
   }
 
 private:



More information about the Mlir-commits mailing list