[Mlir-commits] [mlir] [mlir][LLVM] Support `immargs` in LLVM_IntrOpBase intrinsics (PR #73013)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Nov 21 08:51:12 PST 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/73013

This extends `LLVM_IntrOpBase` so that it can be passed a list of `immArgPositions` and a list (of the same length) of `immArgAttrNames`. `immArgPositions` contains the positions of `immargs` on the LLVM IR intrinsic, and `immArgAttrNames` maps those to a corresponding MLIR attribute.

This allows modeling LLVM `immargs` as MLIR attributes, which is the closest match semantically (and had already been done manually for the LLVM dialect intrinsics).

This has two upsides:
  * It's slightly easier to implement intrinsics with immargs now (especially if they make use of other features, such as overloads)
  * It clearly defines that `immargs` should map to attributes; previously there was no mention of `immargs` in LLVMOpBase.td, so it was unclear how to implement them

This works with other features of the `LLVM_IntrOpBase`, so `immargs` can be marked as overloaded too (which is used in some intrinsics).

As part of this patch (and to test correctness) existing intrinsics have been updated to use these new parameters.

This also uncovered a few issues with the
`llvm.intr.vector.insert/extract` intrinsics. First, the argument order for insert did not match the LLVM intrinsic; second, both were missing an `mlirBuilder` (so they failed to import from LLVM IR). This patch corrects both issues (and adds a test case).

>From ac45f83e1fc8936dbc6ee2bf4c49894e93aee062 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 21 Nov 2023 12:57:16 +0000
Subject: [PATCH] [mlir][LLVM] Support `immargs` in LLVM_IntrOpBase intrinsics

This extends `LLVM_IntrOpBase` so that it can be passed a list of
`immArgPositions` and a list (of the same length) of `immArgAttrNames`.
`immArgPositions` contains the positions of `immargs` on the LLVM IR
intrinsic, and `immArgAttrNames` maps those to a corresponding MLIR
attribute.

This allows modeling LLVM `immargs` as MLIR attributes, which is the
closest match semantically (and had already been done manually for the
LLVM dialect intrinsics).

This has two upsides:
  * It's slightly easier to implement intrinsics with immargs now
    (especially if they make use of other features, such as overloads)
  * It clearly defines that `immargs` should map to attributes; previously
    there was no mention of `immargs` in LLVMOpBase.td, so it was unclear
    how to implement them

This works with other features of the `LLVM_IntrOpBase`, so `immargs`
can be marked as overloaded too (which is used in some intrinsics).

As part of this patch (and to test correctness) existing intrinsics have
been updated to use these new parameters.

This also uncovered a few issues with the
`llvm.intr.vector.insert/extract` intrinsics. First, the argument order
for insert did not match the LLVM intrinsic; second, both were missing
an `mlirBuilder` (so they failed to import from LLVM IR). This patch
corrects both issues (and adds a test case).
---
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   | 205 ++++--------------
 .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td |  83 +++----
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  15 ++
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  11 +
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |   2 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  39 ++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  49 +++++
 mlir/test/Target/LLVMIR/Import/intrinsic.ll   |  16 ++
 8 files changed, 213 insertions(+), 207 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 1d6936ed6c2bf0f..fc088eacfc49744 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -86,54 +86,24 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
 
 class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0],
-           !listconcat([Pure], traits)> {
+           !listconcat([Pure], traits),
+            /*requiresFastmath=*/0,
+            /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_zero_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_zero_poison);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($is_zero_poison));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt1(op.getIsZeroPoison())}, }]
-      # declTypes # [{);
-    $res = inst;
-  }];
 }
 
-def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure]> {
+def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
+    /*requiresFastmath=*/0,
+    /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_int_min_poison);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($is_int_min_poison));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt1(op.getIsIntMinPoison())}, }]
-      # declTypes # [{);
-    $res = inst;
-  }];
 }
 
-def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure]> {
+def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
+  /*requiresFastmath=*/0,
+  /*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($bit));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt32(op.getBit())},
-      }] # declTypes # [{);
-    $res = inst;
-  }];
 }
 
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">;
@@ -148,20 +118,11 @@ def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrOpF<"fmuladd">;
 def LLVM_Log10Op : LLVM_UnaryIntrOpF<"log10">;
 def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
 def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
-def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
+def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
+  /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+  /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
+> {
   let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $addr, $_int_attr($rw), $_int_attr($hint), $_int_attr($cache));
-  }];
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$addr, builder.getInt32(op.getRw()),
-       builder.getInt32(op.getHint()),
-       builder.getInt32(op.getCache())},
-      }] # declTypes # [{);
-  }];
 }
 def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
 def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">;
@@ -211,7 +172,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   AnySignlessInteger:$len, I1Attr:$isVolatile);
@@ -230,29 +192,18 @@ class LLVM_MemcpyIntrOpBase<string name> :
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $src, $len, $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $src, $len,
-       builder.getInt1(op.getIsVolatile())},
-      }] # declTypes # [{ ); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_MemcpyOp : LLVM_MemcpyIntrOpBase<"memcpy">;
 def LLVM_MemmoveOp : LLVM_MemcpyIntrOpBase<"memmove">;
 
 def LLVM_MemcpyInlineOp :
-    LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1],
+    LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2],
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   APIntAttr:$len, I1Attr:$isVolatile);
@@ -271,27 +222,14 @@ def LLVM_MemcpyInlineOp :
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $src, $_int_attr($len), $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $src, builder.getInt(op.getLen()),
-       builder.getInt1(op.getIsVolatile())}, { }]
-      # !interleave(!listconcat(declTypeList, [
-       [{ moduleTranslation.convertType(op.getLenAttr().getType()) }]
-       ]), ", ") # [{ }); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
   // Append the alias attributes defined by LLVM_IntrOpBase.
@@ -309,18 +247,6 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $val, $len, $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $val, $len,
-       builder.getInt1(op.getIsVolatile())},
-      }] # declTypes # [{ ); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_NoAliasScopeDeclOp
@@ -354,38 +280,16 @@ def LLVM_NoAliasScopeDeclOp
 
 /// Base operation for lifetime markers. The LLVM intrinsics require the size
 /// operand to be an immediate. In MLIR it is encoded as an attribute.
-class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [],
-    [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
+class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
+    [DeclareOpInterfaceMethods<PromotableOpInterface>],
+    /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+    /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
   let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
-
-  // Custom builder to convert the size attribute to an integer.
-  let llvmBuilder = [{
-    llvm::Module *module = builder.GetInsertBlock()->getModule();
-    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
-        module, llvm::Intrinsic::}] # llvmEnumName # [{, {}] #
-        !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [0]>.lst, ", ")
-        # [{});
-    builder.CreateCall(fn, {builder.getInt64(op.getSizeAttr().getInt()),
-                            moduleTranslation.lookupValue(op.getPtr())});
-  }];
-
   let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
 }
 
-def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start"> {
-  // Custom builder to convert the size argument to an attribute.
-  string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::LifetimeStartOp>(
-      $_location, $_int_attr($size), $ptr);
-  }];
-}
-def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end"> {
-  // Custom builder to convert the size argument to an attribute.
-  string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::LifetimeEndOp>(
-      $_location, $_int_attr($size), $ptr);
-  }];
-}
+def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start">;
+def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
 
 // Intrinsics with multiple returns.
 
@@ -441,20 +345,12 @@ def LLVM_ExpectOp
 
 def LLVM_ExpectWithProbabilityOp
   : LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
-                         [Pure, AllTypesMatch<["val", "expected", "res"]>]> {
+                         [Pure, AllTypesMatch<["val", "expected", "res"]>],
+                         /*requiresFastmath=*/0,
+                         /*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
   let arguments = (ins AnySignlessInteger:$val,
                        AnySignlessInteger:$expected,
                        F64Attr:$prob);
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::expect_with_probability,
-      {$val, $expected, llvm::ConstantFP::get(builder.getDoubleTy(), $prob)},
-      {$_resultType});
-  }];
-  string mlirBuilder = [{
-    $res = $_builder.create<LLVM::ExpectWithProbabilityOp>(
-      $_location, $val, $expected, $_float_attr($prob));
-  }];
   let assemblyFormat = "$val `,` $expected `,` $prob attr-dict `:` type($val)";
 }
 
@@ -962,16 +858,11 @@ def LLVM_Trap : LLVM_ZeroResultIntrOp<"trap">;
 
 def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
 
-def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap"> {
+def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
+  /*overloadedOperands=*/[], /*traits=*/[],
+  /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+  /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
   let arguments = (ins I8Attr:$failureKind);
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::ubsantrap, {builder.getInt8($failureKind)});
-  }];
-  string mlirBuilder = [{
-    $_op =
-      $_builder.create<LLVM::UBSanTrap>($_location, $_int_attr($failureKind));
-  }];
 }
 
 /// Create a call to vscale intrinsic.
@@ -987,23 +878,21 @@ def LLVM_StepVectorOp
 
 /// Create a call to vector.insert intrinsic
 def LLVM_vector_insert
-    : LLVM_Op<"intr.vector.insert",
-                 [Pure, AllTypesMatch<["dstvec", "res"]>,
+    : LLVM_OneResultIntrOp<"vector.insert",
+                  /*overloadedResults=*/[0], /*overloadedOperands=*/[1],
+                  /*traits=*/[Pure, AllTypesMatch<["dstvec", "res"]>,
                   PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
                     CPred<"getSrcVectorBitWidth() <= 131072">,
                     CPred<"getDstVectorBitWidth() <= 131072">
                   ]>>,
                   PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
                     CPred<"!isScalableVectorType($srcvec.getType()) || "
-                          "isScalableVectorType($dstvec.getType())">>]> {
-  let arguments = (ins LLVM_AnyVector:$srcvec, LLVM_AnyVector:$dstvec,
+                          "isScalableVectorType($dstvec.getType())">>],
+                  /*requiresFastmath=*/0,
+                  /*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
+  let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
                        I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);
-  let builders = [LLVM_OneResultOpBuilder];
-  string llvmBuilder = [{
-    $res = builder.CreateInsertVector(
-        $_resultType, $dstvec, $srcvec, builder.getInt64($pos));
-  }];
   let assemblyFormat = "$srcvec `,` $dstvec `[` $pos `]` attr-dict `:` "
     "type($srcvec) `into` type($res)";
   let extraClassDeclaration = [{
@@ -1022,22 +911,20 @@ def LLVM_vector_insert
 
 /// Create a call to vector.extract intrinsic
 def LLVM_vector_extract
-    : LLVM_Op<"intr.vector.extract",
-                 [Pure,
+    : LLVM_OneResultIntrOp<"vector.extract",
+                 /*overloadedResults=*/[0], /*overloadedOperands=*/[0],
+                 /*traits=*/[Pure,
                   PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
                     CPred<"getSrcVectorBitWidth() <= 131072">,
                     CPred<"getResVectorBitWidth() <= 131072">
                   ]>>,
                   PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
                     CPred<"!isScalableVectorType($res.getType()) || "
-                          "isScalableVectorType($srcvec.getType())">>]> {
+                          "isScalableVectorType($srcvec.getType())">>],
+                  /*requiresFastmath=*/0,
+                  /*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);
-  let builders = [LLVM_OneResultOpBuilder];
-  string llvmBuilder = [{
-    $res = builder.CreateExtractVector(
-        $_resultType, $srcvec, builder.getInt64($pos));
-  }];
   let assemblyFormat = "$srcvec `[` $pos `]` attr-dict `:` "
     "type($res) `from` type($srcvec)";
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 4e42a0e46d9bf9c..3ce1d27748302e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -231,27 +231,6 @@ class LLVM_MemOpPatterns {
   }];
 }
 
-// Patterns with code obtaining the LLVM IR type of the given operand or result
-// of operation. "$0" is expected to be replaced by the position of the operand
-// or result in the operation.
-def LLVM_IntrPatterns {
-  string operand =
-    [{moduleTranslation.convertType(opInst.getOperand($0).getType())}];
-  string result =
-    [{moduleTranslation.convertType(opInst.getResult($0).getType())}];
-  string structResult =
-    [{moduleTranslation.convertType(
-        ::llvm::cast<LLVM::LLVMStructType>(opInst.getResult(0).getType())
-              .getBody()[$0])}];
-}
-
-// For every value in the list, substitutes the value in the place of "$0" in
-// "pattern" and stores the list of strings as "lst".
-class ListIntSubst<string pattern, list<int> values> {
-  list<string> lst = !foreach(x, values,
-                              !subst("$0", !cast<string>(x), pattern));
-}
-
 //===----------------------------------------------------------------------===//
 // Base classes for LLVM dialect operations.
 //===----------------------------------------------------------------------===//
@@ -300,7 +279,9 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
                       bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0>
+                      bit requiresFastmath = 0,
+                      list<int> immArgPositions = [],
+                      list<string> immArgAttrNames = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
         !if(!gt(requiresAccessGroup, 0),
             [DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -320,37 +301,38 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                  OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
                  OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
             (ins )));
-  string resultPattern = !if(!gt(numResults, 1),
-                             LLVM_IntrPatterns.structResult,
-                             LLVM_IntrPatterns.result);
   string llvmEnumName = enumName;
-  list<string> declTypeList = !listconcat(
-            ListIntSubst<resultPattern, overloadedResults>.lst,
-            ListIntSubst<LLVM_IntrPatterns.operand,
-                         overloadedOperands>.lst);
-  string declTypes = [{ { }] # !interleave(declTypeList, ", ") # [{ } }];
+  string overloadedResultsCpp =  "{" # !interleave(overloadedResults, ", ") # "}";
+  string overloadedOperandsCpp =  "{" # !interleave(overloadedOperands, ", ") # "}";
+  string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
+  string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
+    "StringLiteral(\"" # name # "\")"), ", ") # "}";
   let llvmBuilder = [{
-    llvm::Module *module = builder.GetInsertBlock()->getModule();
-    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
-        module,
-        llvm::Intrinsic::}] # enumName # [{,}] # declTypes # [{);
-    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-    }] # [{
-    auto *inst = builder.CreateCall(fn, operands);
+    auto *inst = LLVM::detail::createIntrinsicCall(
+      builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
+        enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
+        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
     (void) inst;
     }] # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
        # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
        # !if(!gt(numResults, 0), "$res = inst;", "");
 
   string mlirBuilder = [{
-    FailureOr<SmallVector<Value>> mlirOperands =
-      moduleImport.convertValues(llvmOperands);
-    if (failed(mlirOperands))
+    SmallVector<Value> mlirOperands;
+    SmallVector<NamedAttribute> mlirAttrs;
+    if (failed(LLVM::detail::convertIntrinsicArguments(moduleImport,
+      llvmOperands,
+      }] # immArgPositionsCpp # [{,
+      }] # immArgAttrNamesCpp # [{,
+      mlirOperands,
+      mlirAttrs))
+    ) {
       return failure();
+    }
     SmallVector<Type> resultTypes =
     }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
     auto op = $_builder.create<$_qualCppClassName>(
-      $_location, resultTypes, *mlirOperands);
+      $_location, resultTypes, mlirOperands, mlirAttrs);
     }] # !if(!gt(requiresFastmath, 0),
       "moduleImport.setFastmathFlagsAttr(inst, op);", "")
     # !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
@@ -361,11 +343,13 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
                   int numResults, bit requiresAccessGroup = 0,
-                  bit requiresAliasAnalysis = 0, bit requiresFastmath = 0>
+                  bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                  list<int> immArgPositions = [],
+                  list<string> immArgAttrNames = []>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
                       overloadedResults, overloadedOperands, traits,
                       numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath>;
+                      requiresFastmath, immArgPositions, immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -384,9 +368,12 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
 class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             list<Trait> traits = [],
                             bit requiresAccessGroup = 0,
-                            bit requiresAliasAnalysis = 0>
+                            bit requiresAliasAnalysis = 0,
+                            list<int> immArgPositions = [],
+                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
-                  requiresAccessGroup, requiresAliasAnalysis>;
+                  requiresAccessGroup, requiresAliasAnalysis,
+                  /*requiresFastMath=*/0, immArgPositions, immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -397,10 +384,12 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
 class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
                            list<Trait> traits = [],
-                           bit requiresFastmath = 0>
+                           bit requiresFastmath = 0,
+                          list<int> immArgPositions = [],
+                          list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-                  requiresFastmath>;
+                  requiresFastmath, immArgPositions, immArgAttrNames>;
 
 def LLVM_OneResultOpBuilder :
   OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 5f5adbff6c04ef8..042c153189b114f 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -380,6 +380,21 @@ class ModuleImport {
   bool emitExpensiveWarnings;
 };
 
+namespace detail {
+
+/// Converts the LLVM `values` of an intrinsic call into MLIR operands
+/// (`valuesOut`) and named attributes (`attrsOut`). Attributes model LLVM
+/// immargs: `immArgPositions` lists their positions on the LLVM intrinsic and
+/// `immArgAttrNames` (same length) the matching MLIR attribute names. Fails
+/// if the values cannot be converted or an immarg is not an int/float constant.
+LogicalResult convertIntrinsicArguments(
+    ModuleImport &moduleImport, ArrayRef<llvm::Value *> values,
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    SmallVectorImpl<Value> &valuesOut,
+    SmallVectorImpl<NamedAttribute> &attrsOut);
+
+} // namespace detail
+
 } // namespace LLVM
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index f9026f84935be52..4820e826d0ca357 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -394,6 +394,17 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder,
                                     llvm::Intrinsic::ID intrinsic,
                                     ArrayRef<llvm::Value *> args = {},
                                     ArrayRef<llvm::Type *> tys = {});
+
+/// Creates a call to the LLVM IR intrinsic `intrinsic` for an op defined via
+/// LLVM_IntrOpBase: immarg attributes and operands of `intrOp` become the call
+/// arguments, and the overloaded result/operand types select the declaration.
+llvm::CallInst *createIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
+    ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
+    ArrayRef<unsigned> immArgPositions,
+    ArrayRef<StringLiteral> immArgAttrNames);
+
 } // namespace detail
 
 } // namespace LLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 75a35b4c801e4a5..cd5df0be740b9c0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1276,7 +1276,7 @@ struct VectorScalableInsertOpLowering
   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
-        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
     return success();
   }
 };
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d9e039e75e9ef24..b749f097e11ffc2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1199,6 +1199,45 @@ ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
   return remapped;
 }
 
+LogicalResult mlir::LLVM::detail::convertIntrinsicArguments(
+    ModuleImport &moduleImport, ArrayRef<llvm::Value *> values,
+    ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+    SmallVectorImpl<Value> &valuesOut,
+    SmallVectorImpl<NamedAttribute> &attrsOut) {
+  assert(immArgPositions.size() == immArgAttrNames.size() &&
+         "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
+         "length");
+
+  auto maybeMlirValues = moduleImport.convertValues(values);
+  if (failed(maybeMlirValues))
+    return failure();
+
+  auto mlirValues = std::move(maybeMlirValues).value();
+  for (auto [immArgPos, immArgName] :
+       llvm::zip(immArgPositions, immArgAttrNames)) {
+    IntegerAttr integerAttr;
+    FloatAttr floatAttr;
+    Value &value = mlirValues[immArgPos];
+    auto nameAttr = StringAttr::get(value.getContext(), immArgName);
+    if (matchPattern(value, m_Constant(&integerAttr)))
+      attrsOut.push_back({nameAttr, integerAttr});
+    else if (matchPattern(value, m_Constant(&floatAttr)))
+      attrsOut.push_back({nameAttr, floatAttr});
+    else {
+      // Not an integer or float constant: report failure to the caller.
+      return failure();
+    }
+    // Mark matched attribute values as null (so they can be removed below).
+    value = nullptr;
+  }
+
+  // Only the values that were not turned into attributes remain operands.
+  llvm::copy_if(mlirValues, std::back_inserter(valuesOut),
+                [](Value value) { return bool(value); });
+
+  return success();
+}
+
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 911c7141e45d5f2..b9272273c39f03f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -580,6 +580,55 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(fn, args);
 }
 
+llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
+    ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
+    ArrayRef<unsigned> immArgPositions,
+    ArrayRef<StringLiteral> immArgAttrNames) {
+  assert(immArgPositions.size() == immArgAttrNames.size() &&
+         "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
+         "length");
+
+  // The call takes one argument per SSA operand plus one per immarg attribute.
+  // Slots start out null; immargs are placed first, at their LLVM positions.
+  SmallVector<llvm::Value *> ssaArgs =
+      moduleTranslation.lookupValues(intrOp->getOperands());
+  SmallVector<llvm::Value *> callArgs(immArgPositions.size() + ssaArgs.size());
+  for (auto immArg : llvm::zip(immArgPositions, immArgAttrNames)) {
+    unsigned position = std::get<0>(immArg);
+    auto immAttr = llvm::cast<TypedAttr>(intrOp->getAttr(std::get<1>(immArg)));
+    Type immType = immAttr.getType();
+    assert(immType.isIntOrFloat() && "expected int or float immarg");
+    callArgs[position] = LLVM::detail::getLLVMConstant(
+        moduleTranslation.convertType(immType), immAttr, intrOp->getLoc(),
+        moduleTranslation);
+  }
+  // Remaining (still null) slots take the SSA operands in order.
+  unsigned nextSsaArg = 0;
+  for (llvm::Value *&slot : callArgs)
+    if (!slot)
+      slot = ssaArgs[nextSsaArg++];
+
+  // Collect the types that select the overloaded intrinsic declaration:
+  // overloaded results first, then overloaded operands.
+  SmallVector<llvm::Type *> overloadTys;
+  for (unsigned resultIdx : overloadedResults) {
+    Type resultType = intrOp->getResult(0).getType();
+    // More than one result is mapped to an LLVM struct.
+    if (numResults > 1)
+      resultType =
+          llvm::cast<LLVM::LLVMStructType>(resultType).getBody()[resultIdx];
+    overloadTys.push_back(moduleTranslation.convertType(resultType));
+  }
+  for (unsigned operandIdx : overloadedOperands)
+    overloadTys.push_back(callArgs[operandIdx]->getType());
+
+  llvm::Function *decl = llvm::Intrinsic::getDeclaration(
+      builder.GetInsertBlock()->getModule(), intrinsic, overloadTys);
+  return builder.CreateCall(decl, callArgs);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 8ce16fe5705cb86..1efadcb88b8986c 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -755,6 +755,20 @@ define void @lifetime(ptr %0) {
   ret void
 }
 
+; CHECK-LABEL: llvm.func @vector_insert
+define void @vector_insert(<vscale x 4 x float> %0, <4 x float> %1) {
+  ; CHECK: llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into !llvm.vec<? x 4 x  f32>
+  %3 = call <vscale x 4 x float>  @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> %0, <4 x float> %1, i64 4);
+  ret void
+}
+
+; CHECK-LABEL: llvm.func @vector_extract
+define void @vector_extract(<vscale x 4 x float> %0) {
+  ; CHECK: llvm.intr.vector.extract %{{.*}}[0] : vector<4xf32> from !llvm.vec<? x 4 x  f32>
+  %2 = call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %0, i64 0);
+  ret void
+}
+
 ; CHECK-LABEL:  llvm.func @vector_predication_intrinsics
 define void @vector_predication_intrinsics(<8 x i32> %0, <8 x i32> %1, <8 x float> %2, <8 x float> %3, <8 x i64> %4, <8 x double> %5, <8 x ptr> %6, i32 %7, float %8, ptr %9, ptr %10, <8 x i1> %11, i32 %12) {
   ; CHECK: "llvm.intr.vp.add"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
@@ -1085,3 +1099,5 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
 declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
 declare void @llvm.assume(i1)
 declare float @llvm.ssa.copy.f32(float returned)
+declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
+declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)



More information about the Mlir-commits mailing list