[Mlir-commits] [mlir] dbb8643 - [mlir][LLVM] Support `immargs` in LLVM_IntrOpBase intrinsics (#73013)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 23 02:12:17 PST 2023


Author: Benjamin Maxwell
Date: 2023-11-23T10:12:12Z
New Revision: dbb86433333154226233da46271e8c520e8a5119

URL: https://github.com/llvm/llvm-project/commit/dbb86433333154226233da46271e8c520e8a5119
DIFF: https://github.com/llvm/llvm-project/commit/dbb86433333154226233da46271e8c520e8a5119.diff

LOG: [mlir][LLVM] Support `immargs` in LLVM_IntrOpBase intrinsics (#73013)

This extends `LLVM_IntrOpBase` so that it can be passed a list of
`immArgPositions` and a list (of the same length) of `immArgAttrNames`.
`immArgPositions` contains the positions of `immargs` on the LLVM IR
intrinsic, and `immArgAttrNames` maps those to a corresponding MLIR
attribute.

This allows modeling LLVM `immargs` as MLIR attributes, which is the
closest match semantically (and is how the existing LLVM dialect
intrinsics had already been modeled manually).

This has two upsides:
* It's now slightly easier to implement intrinsics with immargs
(especially if they make use of other features, such as overloads).
* It clearly defines that `immargs` should map to attributes. Previously
there was no mention of `immargs` in LLVMOpBase.td, so how to implement
them was unclear.

This works with the other features of `LLVM_IntrOpBase`, so `immargs`
can also be marked as overloaded (which some intrinsics require, e.g.
the `len` argument of `llvm.memcpy.inline`).

As part of this patch (and to test correctness) existing intrinsics have
been updated to use these new parameters.

This also uncovered two issues with the
`llvm.intr.vector.insert/extract` intrinsics. First, the argument order
for insert did not match the LLVM intrinsic; second, both were missing
an mlirBuilder (so they failed to import from LLVM IR). Both are
corrected by this patch, and a test case has been added.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/Import/intrinsic.ll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 1d6936ed6c2bf0f..fc088eacfc49744 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -86,54 +86,24 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
 
 class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0],
-           !listconcat([Pure], traits)> {
+           !listconcat([Pure], traits),
+            /*requiresFastmath=*/0,
+            /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_zero_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_zero_poison);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($is_zero_poison));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt1(op.getIsZeroPoison())}, }]
-      # declTypes # [{);
-    $res = inst;
-  }];
 }
 
-def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure]> {
+def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
+    /*requiresFastmath=*/0,
+    /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_int_min_poison);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($is_int_min_poison));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt1(op.getIsIntMinPoison())}, }]
-      # declTypes # [{);
-    $res = inst;
-  }];
 }
 
-def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure]> {
+def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
+  /*requiresFastmath=*/0,
+  /*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
-  string mlirBuilder = [{
-    auto op = $_builder.create<$_qualCppClassName>($_location,
-      $_resultType, $in, $_int_attr($bit));
-      $res = op;
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$in, builder.getInt32(op.getBit())},
-      }] # declTypes # [{);
-    $res = inst;
-  }];
 }
 
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">;
@@ -148,20 +118,11 @@ def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrOpF<"fmuladd">;
 def LLVM_Log10Op : LLVM_UnaryIntrOpF<"log10">;
 def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
 def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
-def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
+def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
+  /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+  /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
+> {
   let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $addr, $_int_attr($rw), $_int_attr($hint), $_int_attr($cache));
-  }];
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$addr, builder.getInt32(op.getRw()),
-       builder.getInt32(op.getHint()),
-       builder.getInt32(op.getCache())},
-      }] # declTypes # [{);
-  }];
 }
 def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
 def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">;
@@ -211,7 +172,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   AnySignlessInteger:$len, I1Attr:$isVolatile);
@@ -230,29 +192,18 @@ class LLVM_MemcpyIntrOpBase<string name> :
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $src, $len, $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $src, $len,
-       builder.getInt1(op.getIsVolatile())},
-      }] # declTypes # [{ ); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_MemcpyOp : LLVM_MemcpyIntrOpBase<"memcpy">;
 def LLVM_MemmoveOp : LLVM_MemcpyIntrOpBase<"memmove">;
 
 def LLVM_MemcpyInlineOp :
-    LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1],
+    LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2],
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   APIntAttr:$len, I1Attr:$isVolatile);
@@ -271,27 +222,14 @@ def LLVM_MemcpyInlineOp :
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $src, $_int_attr($len), $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $src, builder.getInt(op.getLen()),
-       builder.getInt1(op.getIsVolatile())}, { }]
-      # !interleave(!listconcat(declTypeList, [
-       [{ moduleTranslation.convertType(op.getLenAttr().getType()) }]
-       ]), ", ") # [{ }); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
     [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
   // Append the alias attributes defined by LLVM_IntrOpBase.
@@ -309,18 +247,6 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
             /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
     }]>
   ];
-  string mlirBuilder = [{
-    $_op = $_builder.create<$_qualCppClassName>($_location,
-      $dst, $val, $len, $_int_attr($isVolatile));
-  }];
-  string llvmBuilder = [{
-    auto *inst = createIntrinsicCall(
-      builder, llvm::Intrinsic::}] # llvmEnumName # [{,
-      {$dst, $val, $len,
-       builder.getInt1(op.getIsVolatile())},
-      }] # declTypes # [{ ); }]
-      # setAccessGroupsMetadataCode
-      # setAliasAnalysisMetadataCode;
 }
 
 def LLVM_NoAliasScopeDeclOp
@@ -354,38 +280,16 @@ def LLVM_NoAliasScopeDeclOp
 
 /// Base operation for lifetime markers. The LLVM intrinsics require the size
 /// operand to be an immediate. In MLIR it is encoded as an attribute.
-class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [],
-    [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
+class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
+    [DeclareOpInterfaceMethods<PromotableOpInterface>],
+    /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+    /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
   let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
-
-  // Custom builder to convert the size attribute to an integer.
-  let llvmBuilder = [{
-    llvm::Module *module = builder.GetInsertBlock()->getModule();
-    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
-        module, llvm::Intrinsic::}] # llvmEnumName # [{, {}] #
-        !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [0]>.lst, ", ")
-        # [{});
-    builder.CreateCall(fn, {builder.getInt64(op.getSizeAttr().getInt()),
-                            moduleTranslation.lookupValue(op.getPtr())});
-  }];
-
   let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
 }
 
-def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start"> {
-  // Custom builder to convert the size argument to an attribute.
-  string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::LifetimeStartOp>(
-      $_location, $_int_attr($size), $ptr);
-  }];
-}
-def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end"> {
-  // Custom builder to convert the size argument to an attribute.
-  string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::LifetimeEndOp>(
-      $_location, $_int_attr($size), $ptr);
-  }];
-}
+def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start">;
+def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
 
 // Intrinsics with multiple returns.
 
@@ -441,20 +345,12 @@ def LLVM_ExpectOp
 
 def LLVM_ExpectWithProbabilityOp
   : LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
-                         [Pure, AllTypesMatch<["val", "expected", "res"]>]> {
+                         [Pure, AllTypesMatch<["val", "expected", "res"]>],
+                         /*requiresFastmath=*/0,
+                         /*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
   let arguments = (ins AnySignlessInteger:$val,
                        AnySignlessInteger:$expected,
                        F64Attr:$prob);
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::expect_with_probability,
-      {$val, $expected, llvm::ConstantFP::get(builder.getDoubleTy(), $prob)},
-      {$_resultType});
-  }];
-  string mlirBuilder = [{
-    $res = $_builder.create<LLVM::ExpectWithProbabilityOp>(
-      $_location, $val, $expected, $_float_attr($prob));
-  }];
   let assemblyFormat = "$val `,` $expected `,` $prob attr-dict `:` type($val)";
 }
 
@@ -962,16 +858,11 @@ def LLVM_Trap : LLVM_ZeroResultIntrOp<"trap">;
 
 def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
 
-def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap"> {
+def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
+  /*overloadedOperands=*/[], /*traits=*/[],
+  /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+  /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
   let arguments = (ins I8Attr:$failureKind);
-  string llvmBuilder = [{
-    createIntrinsicCall(
-      builder, llvm::Intrinsic::ubsantrap, {builder.getInt8($failureKind)});
-  }];
-  string mlirBuilder = [{
-    $_op =
-      $_builder.create<LLVM::UBSanTrap>($_location, $_int_attr($failureKind));
-  }];
 }
 
 /// Create a call to vscale intrinsic.
@@ -987,23 +878,21 @@ def LLVM_StepVectorOp
 
 /// Create a call to vector.insert intrinsic
 def LLVM_vector_insert
-    : LLVM_Op<"intr.vector.insert",
-                 [Pure, AllTypesMatch<["dstvec", "res"]>,
+    : LLVM_OneResultIntrOp<"vector.insert",
+                  /*overloadedResults=*/[0], /*overloadedOperands=*/[1],
+                  /*traits=*/[Pure, AllTypesMatch<["dstvec", "res"]>,
                   PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
                     CPred<"getSrcVectorBitWidth() <= 131072">,
                     CPred<"getDstVectorBitWidth() <= 131072">
                   ]>>,
                   PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
                     CPred<"!isScalableVectorType($srcvec.getType()) || "
-                          "isScalableVectorType($dstvec.getType())">>]> {
-  let arguments = (ins LLVM_AnyVector:$srcvec, LLVM_AnyVector:$dstvec,
+                          "isScalableVectorType($dstvec.getType())">>],
+                  /*requiresFastmath=*/0,
+                  /*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
+  let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
                        I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);
-  let builders = [LLVM_OneResultOpBuilder];
-  string llvmBuilder = [{
-    $res = builder.CreateInsertVector(
-        $_resultType, $dstvec, $srcvec, builder.getInt64($pos));
-  }];
   let assemblyFormat = "$srcvec `,` $dstvec `[` $pos `]` attr-dict `:` "
     "type($srcvec) `into` type($res)";
   let extraClassDeclaration = [{
@@ -1022,22 +911,20 @@ def LLVM_vector_insert
 
 /// Create a call to vector.extract intrinsic
 def LLVM_vector_extract
-    : LLVM_Op<"intr.vector.extract",
-                 [Pure,
+    : LLVM_OneResultIntrOp<"vector.extract",
+                 /*overloadedResults=*/[0], /*overloadedOperands=*/[0],
+                 /*traits=*/[Pure,
                   PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
                     CPred<"getSrcVectorBitWidth() <= 131072">,
                     CPred<"getResVectorBitWidth() <= 131072">
                   ]>>,
                   PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
                     CPred<"!isScalableVectorType($res.getType()) || "
-                          "isScalableVectorType($srcvec.getType())">>]> {
+                          "isScalableVectorType($srcvec.getType())">>],
+                  /*requiresFastmath=*/0,
+                  /*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);
-  let builders = [LLVM_OneResultOpBuilder];
-  string llvmBuilder = [{
-    $res = builder.CreateExtractVector(
-        $_resultType, $srcvec, builder.getInt64($pos));
-  }];
   let assemblyFormat = "$srcvec `[` $pos `]` attr-dict `:` "
     "type($res) `from` type($srcvec)";
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 4e42a0e46d9bf9c..2ce8d8df05edc42 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -231,27 +231,6 @@ class LLVM_MemOpPatterns {
   }];
 }
 
-// Patterns with code obtaining the LLVM IR type of the given operand or result
-// of operation. "$0" is expected to be replaced by the position of the operand
-// or result in the operation.
-def LLVM_IntrPatterns {
-  string operand =
-    [{moduleTranslation.convertType(opInst.getOperand($0).getType())}];
-  string result =
-    [{moduleTranslation.convertType(opInst.getResult($0).getType())}];
-  string structResult =
-    [{moduleTranslation.convertType(
-        ::llvm::cast<LLVM::LLVMStructType>(opInst.getResult(0).getType())
-              .getBody()[$0])}];
-}
-
-// For every value in the list, substitutes the value in the place of "$0" in
-// "pattern" and stores the list of strings as "lst".
-class ListIntSubst<string pattern, list<int> values> {
-  list<string> lst = !foreach(x, values,
-                              !subst("$0", !cast<string>(x), pattern));
-}
-
 //===----------------------------------------------------------------------===//
 // Base classes for LLVM dialect operations.
 //===----------------------------------------------------------------------===//
@@ -295,12 +274,19 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
 // interfaces the intrinsic implements. If the corresponding flags are set, the
 // "aliasAttrs" list contains the arguments required by the access group and
 // alias analysis interfaces. Derived intrinsics should append the "aliasAttrs"
-// to their argument list if they set one of the flags.
+// to their argument list if they set one of the flags. LLVM `immargs` can be
+// represented as MLIR attributes by providing both the `immArgPositions` and
+// `immArgAttrNames` lists. These two lists should have equal length, with
+// `immArgPositions` containing the argument positions on the LLVM IR intrinsic
+// that are `immargs`, and `immArgAttrNames` mapping these to corresponding
+// MLIR attributes.
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
                       bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0>
+                      bit requiresFastmath = 0,
+                      list<int> immArgPositions = [],
+                      list<string> immArgAttrNames = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
         !if(!gt(requiresAccessGroup, 0),
             [DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -320,37 +306,38 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                  OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
                  OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
             (ins )));
-  string resultPattern = !if(!gt(numResults, 1),
-                             LLVM_IntrPatterns.structResult,
-                             LLVM_IntrPatterns.result);
   string llvmEnumName = enumName;
-  list<string> declTypeList = !listconcat(
-            ListIntSubst<resultPattern, overloadedResults>.lst,
-            ListIntSubst<LLVM_IntrPatterns.operand,
-                         overloadedOperands>.lst);
-  string declTypes = [{ { }] # !interleave(declTypeList, ", ") # [{ } }];
+  string overloadedResultsCpp =  "{" # !interleave(overloadedResults, ", ") # "}";
+  string overloadedOperandsCpp =  "{" # !interleave(overloadedOperands, ", ") # "}";
+  string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
+  string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
+    "StringLiteral(\"" # name # "\")"), ", ") # "}";
   let llvmBuilder = [{
-    llvm::Module *module = builder.GetInsertBlock()->getModule();
-    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
-        module,
-        llvm::Intrinsic::}] # enumName # [{,}] # declTypes # [{);
-    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-    }] # [{
-    auto *inst = builder.CreateCall(fn, operands);
+    auto *inst = LLVM::detail::createIntrinsicCall(
+      builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
+        enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
+        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
     (void) inst;
     }] # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
        # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
        # !if(!gt(numResults, 0), "$res = inst;", "");
 
   string mlirBuilder = [{
-    FailureOr<SmallVector<Value>> mlirOperands =
-      moduleImport.convertValues(llvmOperands);
-    if (failed(mlirOperands))
+    SmallVector<Value> mlirOperands;
+    SmallVector<NamedAttribute> mlirAttrs;
+    if (failed(moduleImport.convertIntrinsicArguments(
+      llvmOperands,
+      }] # immArgPositionsCpp # [{,
+      }] # immArgAttrNamesCpp # [{,
+      mlirOperands,
+      mlirAttrs))
+    ) {
       return failure();
+    }
     SmallVector<Type> resultTypes =
     }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
     auto op = $_builder.create<$_qualCppClassName>(
-      $_location, resultTypes, *mlirOperands);
+      $_location, resultTypes, mlirOperands, mlirAttrs);
     }] # !if(!gt(requiresFastmath, 0),
       "moduleImport.setFastmathFlagsAttr(inst, op);", "")
     # !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
@@ -361,11 +348,13 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
                   int numResults, bit requiresAccessGroup = 0,
-                  bit requiresAliasAnalysis = 0, bit requiresFastmath = 0>
+                  bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                  list<int> immArgPositions = [],
+                  list<string> immArgAttrNames = []>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
                       overloadedResults, overloadedOperands, traits,
                       numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath>;
+                      requiresFastmath, immArgPositions, immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -384,9 +373,12 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
 class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             list<Trait> traits = [],
                             bit requiresAccessGroup = 0,
-                            bit requiresAliasAnalysis = 0>
+                            bit requiresAliasAnalysis = 0,
+                            list<int> immArgPositions = [],
+                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
-                  requiresAccessGroup, requiresAliasAnalysis>;
+                  requiresAccessGroup, requiresAliasAnalysis,
+                  /*requiresFastmath=*/0, immArgPositions, immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -397,10 +389,12 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
 class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
                            list<Trait> traits = [],
-                           bit requiresFastmath = 0>
+                           bit requiresFastmath = 0,
+                          list<int> immArgPositions = [],
+                          list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-                  requiresFastmath>;
+                  requiresFastmath, immArgPositions, immArgAttrNames>;
 
 def LLVM_OneResultOpBuilder :
   OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 5f5adbff6c04ef8..b8e449dc11df155 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -214,6 +214,18 @@ class ModuleImport {
   /// after the function conversion has finished.
   void addDebugIntrinsic(llvm::CallInst *intrinsic);
 
+  /// Converts the LLVM values for an intrinsic to mixed MLIR values and
+  /// attributes for LLVM_IntrOpBase. Attributes correspond to LLVM immargs. The
+  /// list `immArgPositions` contains the positions of immargs on the LLVM
+  /// intrinsic, and `immArgAttrNames` list (of the same length) contains the
+  /// corresponding MLIR attribute names.
+  LogicalResult
+  convertIntrinsicArguments(ArrayRef<llvm::Value *> values,
+                            ArrayRef<unsigned> immArgPositions,
+                            ArrayRef<StringLiteral> immArgAttrNames,
+                            SmallVectorImpl<Value> &valuesOut,
+                            SmallVectorImpl<NamedAttribute> &attrsOut);
+
 private:
   /// Clears the accumulated state before processing a new region.
   void clearRegionState() {

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index f9026f84935be52..4820e826d0ca357 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -394,6 +394,17 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder,
                                     llvm::Intrinsic::ID intrinsic,
                                     ArrayRef<llvm::Value *> args = {},
                                     ArrayRef<llvm::Type *> tys = {});
+
+/// Creates a call to an LLVM IR intrinsic defined by LLVM_IntrOpBase. This
+/// resolves the overloaded types and maps the op's MLIR operands, plus its
+/// immarg attributes placed at `immArgPositions`, to LLVM call arguments.
+llvm::CallInst *createIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
+    ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
+    ArrayRef<unsigned> immArgPositions, // Indices in the LLVM call.
+    ArrayRef<StringLiteral> immArgAttrNames); // Parallel to immArgPositions.
+
 } // namespace detail
 
 } // namespace LLVM

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 75a35b4c801e4a5..cd5df0be740b9c0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1276,7 +1276,7 @@ struct VectorScalableInsertOpLowering
   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
-        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
     return success();
   }
 };

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d9e039e75e9ef24..9cdc1f45d38fa59 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -700,7 +700,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
 
 /// Returns an integer or float attribute for the provided scalar constant
 /// `constScalar` or nullptr if the conversion fails.
-static Attribute getScalarConstantAsAttr(OpBuilder &builder,
+static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
                                          llvm::Constant *constScalar) {
   MLIRContext *context = builder.getContext();
 
@@ -1199,6 +1199,40 @@ ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
   return remapped;
 }
 
+LogicalResult ModuleImport::convertIntrinsicArguments(
+    ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions,
+    ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut,
+    SmallVectorImpl<NamedAttribute> &attrsOut) {
+  assert(immArgPositions.size() == immArgAttrNames.size() &&
+         "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
+         "length");
+
+  SmallVector<llvm::Value *> operands(values); // Mutable copy.
+  for (auto [immArgPos, immArgName] :
+       llvm::zip(immArgPositions, immArgAttrNames)) {
+    auto &value = operands[immArgPos];
+    auto *constant = llvm::cast<llvm::Constant>(value); // immargs are constant.
+    auto attr = getScalarConstantAsAttr(builder, constant);
+    assert(attr && attr.getType().isIntOrFloat() &&
+           "expected immarg to be float or integer constant");
+    auto nameAttr = StringAttr::get(attr.getContext(), immArgName);
+    attrsOut.push_back({nameAttr, attr});
+    // Null out immarg slots; the loop below converts only the rest to values.
+    value = nullptr;
+  }
+
+  for (llvm::Value *value : operands) { // Non-immarg operands, in order.
+    if (!value) // Already emitted as an attribute above.
+      continue;
+    auto mlirValue = convertValue(value);
+    if (failed(mlirValue))
+      return failure();
+    valuesOut.push_back(*mlirValue);
+  }
+
+  return success();
+}
+
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ef1c8c21d54b08f..245048ef0b93939 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -580,6 +580,55 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(fn, args);
 }
 
+/// Creates a call to the LLVM intrinsic `intrinsic` for the MLIR operation
+/// `intrOp`. The LLVM argument list is built by placing the attributes named
+/// in `immArgAttrNames`, converted to LLVM constants, at the matching
+/// `immArgPositions`. The remaining slots are filled, in order, with the
+/// translated operands of `intrOp`.
+/// Overloaded types come from two sources. The first is the results listed in
+/// `overloadedResults`; when `numResults > 1`, these are taken from the body of
+/// the packed struct result. The second is the arguments listed in
+/// `overloadedOperands`, whose indices refer to the final LLVM argument list
+/// (immargs included).
+llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
+    ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
+    ArrayRef<unsigned> immArgPositions,
+    ArrayRef<StringLiteral> immArgAttrNames) {
+  assert(immArgPositions.size() == immArgAttrNames.size() &&
+         "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
+         "length");
+
+  // Map operands and attributes to LLVM values.
+  auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
+  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+  // Track immarg slots explicitly rather than using null entries as a
+  // sentinel. A constant conversion that failed must not be silently refilled
+  // with an operand, because that would shift all following operands.
+  SmallVector<bool> isImmArgSlot(args.size(), false);
+  for (auto [immArgPos, immArgName] :
+       llvm::zip(immArgPositions, immArgAttrNames)) {
+    assert(immArgPos < args.size() && "immarg position out of range");
+    assert(!isImmArgSlot[immArgPos] && "duplicate immarg position");
+    isImmArgSlot[immArgPos] = true;
+    auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
+    assert(attr.getType().isIntOrFloat() && "expected int or float immarg");
+    auto *type = moduleTranslation.convertType(attr.getType());
+    args[immArgPos] = LLVM::detail::getLLVMConstant(
+        type, attr, intrOp->getLoc(), moduleTranslation);
+    assert(args[immArgPos] && "failed to convert immarg to an LLVM constant");
+  }
+  // Fill every non-immarg slot with the next operand, preserving order.
+  unsigned opArg = 0;
+  for (unsigned i = 0, e = args.size(); i < e; ++i) {
+    if (!isImmArgSlot[i])
+      args[i] = operands[opArg++];
+  }
+  assert(opArg == operands.size() && "not all operands were consumed");
+
+  // Resolve overloaded intrinsic declaration.
+  SmallVector<llvm::Type *> overloadedTypes;
+  for (unsigned overloadedResultIdx : overloadedResults) {
+    if (numResults > 1) {
+      // More than one result is mapped to an LLVM struct.
+      overloadedTypes.push_back(moduleTranslation.convertType(
+          llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType())
+              .getBody()[overloadedResultIdx]));
+    } else {
+      overloadedTypes.push_back(
+          moduleTranslation.convertType(intrOp->getResult(0).getType()));
+    }
+  }
+  for (unsigned overloadedOperandIdx : overloadedOperands)
+    overloadedTypes.push_back(args[overloadedOperandIdx]->getType());
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  llvm::Function *llvmIntr =
+      llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+
+  return builder.CreateCall(llvmIntr, args);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult

diff  --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 8ce16fe5705cb86..c8dcde11d93e645 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -465,7 +465,6 @@ define void @trap_intrinsics() {
 
 ; CHECK-LABEL:  llvm.func @memcpy_test
 define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
-  ; CHECK: %[[CST:.+]] = llvm.mlir.constant(10 : i64) : i64
   ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
   ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
@@ -755,6 +754,20 @@ define void @lifetime(ptr %0) {
   ret void
 }
 
+; CHECK-LABEL: llvm.func @vector_insert
+define void @vector_insert(<vscale x 4 x float> %dst, <4 x float> %sub) {
+  ; CHECK: llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into !llvm.vec<? x 4 x f32>
+  %res = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> %dst, <4 x float> %sub, i64 4)
+  ret void
+}
+
+; CHECK-LABEL: llvm.func @vector_extract
+define void @vector_extract(<vscale x 4 x float> %0) {
+  ; CHECK: llvm.intr.vector.extract %{{.*}}[0] : vector<4xf32> from !llvm.vec<? x 4 x f32>
+  %2 = call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %0, i64 0)
+  ret void
+}
+
 ; CHECK-LABEL:  llvm.func @vector_predication_intrinsics
 define void @vector_predication_intrinsics(<8 x i32> %0, <8 x i32> %1, <8 x float> %2, <8 x float> %3, <8 x i64> %4, <8 x double> %5, <8 x ptr> %6, i32 %7, float %8, ptr %9, ptr %10, <8 x i1> %11, i32 %12) {
   ; CHECK: "llvm.intr.vp.add"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
@@ -1085,3 +1098,5 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
 declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
 declare void @llvm.assume(i1)
 declare float @llvm.ssa.copy.f32(float returned)
+declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
+declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)


        


More information about the Mlir-commits mailing list