[Mlir-commits] [mlir] MLIR: Enable importing inlineasm calls (PR #121624)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 3 22:02:47 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: William Moses (wsmoses)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/121624.diff


4 Files Affected:

- (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+4-1) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+59-46) 
- (modified) mlir/test/Target/LLVMIR/Import/import-failure.ll (-9) 
- (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+10) 


``````````diff
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index eea0647895b01b..880e8201318e6c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -319,9 +319,12 @@ class ModuleImport {
   /// Appends the converted result type and operands of `callInst` to the
   /// `types` and `operands` arrays. For indirect calls, the method additionally
   /// inserts the called function at the beginning of the `operands` array.
+  /// If `handleAsm` is false (the default), conversion fails when the called
+  /// operand is inline asm, because inline asm cannot be converted to a Value.
   LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
                                            SmallVectorImpl<Type> &types,
-                                           SmallVectorImpl<Value> &operands);
+                                           SmallVectorImpl<Value> &operands,
+                                           bool handleAsm = false);
   /// Converts the parameter attributes attached to `func` and adds them to the
   /// `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index b0d5e635248d3f..66cfb3f9dca110 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1473,18 +1473,19 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
   return success();
 }
 
-LogicalResult
-ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
-                                         SmallVectorImpl<Type> &types,
-                                         SmallVectorImpl<Value> &operands) {
+LogicalResult ModuleImport::convertCallTypeAndOperands(
+    llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<Value> &operands, bool handleAsm) {
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
   if (!callInst->getCalledFunction()) {
-    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-    if (failed(called))
-      return failure();
-    operands.push_back(*called);
+    if (!handleAsm || !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+      if (failed(called))
+        return failure();
+      operands.push_back(*called);
+    }
   }
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,7 +1580,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(callInst, types, operands)))
+    if (failed(convertCallTypeAndOperands(callInst, types, operands, true)))
       return failure();
 
     auto funcTy =
@@ -1587,45 +1588,57 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (!funcTy)
       return failure();
 
-    CallOp callOp;
-
-    if (llvm::Function *callee = callInst->getCalledFunction()) {
-      callOp = builder.create<CallOp>(
-          loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-          operands);
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      InlineAsmOp callOp = builder.create<InlineAsmOp>(
+          loc, funcTy.getReturnType(), operands,
+          builder.getStringAttr(asmI->getAsmString()),
+          builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
+          nullptr, nullptr);
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult(0));
+      else
+        mapNoResultOp(inst, callOp);
     } else {
-      callOp = builder.create<CallOp>(loc, funcTy, operands);
+      CallOp callOp;
+
+      if (llvm::Function *callee = callInst->getCalledFunction()) {
+        callOp = builder.create<CallOp>(
+            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+            operands);
+      } else {
+        callOp = builder.create<CallOp>(loc, funcTy, operands);
+      }
+      callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+      callOp.setTailCallKind(
+          convertTailCallKindFromLLVM(callInst->getTailCallKind()));
+      setFastmathFlagsAttr(inst, callOp);
+
+      // Handle function attributes.
+      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
+        callOp.setConvergent(true);
+      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
+        callOp.setNoUnwind(true);
+      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
+        callOp.setWillReturn(true);
+
+      llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
+      ModRefInfo othermem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+      ModRefInfo argMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+      ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+      auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
+                                            argMem, inaccessibleMem);
+      // Only set the attribute when it does not match the default value.
+      if (!memAttr.isReadWrite())
+        callOp.setMemoryEffectsAttr(memAttr);
+
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult());
+      else
+        mapNoResultOp(inst, callOp);
     }
-    callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
-    callOp.setTailCallKind(
-        convertTailCallKindFromLLVM(callInst->getTailCallKind()));
-    setFastmathFlagsAttr(inst, callOp);
-
-    // Handle function attributes.
-    if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-      callOp.setConvergent(true);
-    if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-      callOp.setNoUnwind(true);
-    if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-      callOp.setWillReturn(true);
-
-    llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
-    ModRefInfo othermem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::Other));
-    ModRefInfo argMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
-    ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
-    auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
-                                          inaccessibleMem);
-    // Only set the attribute when it does not match the default value.
-    if (!memAttr.isReadWrite())
-      callOp.setMemoryEffectsAttr(memAttr);
-
-    if (!callInst->getType()->isVoidTy())
-      mapValue(inst, callOp.getResult());
-    else
-      mapNoResultOp(inst, callOp);
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::LandingPad) {
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 6bde174642d540..b616cb81e0a8a5 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -12,15 +12,6 @@ bb2:
 
 ; // -----
 
-; CHECK:      <unknown>
-; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
-define i32 @unhandled_value(i32 %arg1) {
-  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
-  ret i32 %1
-}
-
-; // -----
-
 ; CHECK:      <unknown>
 ; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
 ; CHECK:      <unknown>
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index fff48bbc486bc1..42b2cfff9611a3 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -535,6 +535,16 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
 
 ; // -----
 
+; CHECK-LABEL: @inlineasm
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define i32 @inlineasm(i32 %arg1) {
+  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
+  ; CHECK:  llvm.inline_asm has_side_effects asm_dialect = att "bswap $0", "=r,r" (%[[ARG1]]) : (i32) -> (i32)
+  ret i32 %1
+}
+
+; // -----
+
 ; CHECK-LABEL: @gep_static_idx
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @gep_static_idx(ptr %ptr) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/121624


More information about the Mlir-commits mailing list