[Mlir-commits] [mlir] MLIR: Enable importing inlineasm calls (PR #121624)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 3 22:02:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: William Moses (wsmoses)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/121624.diff
4 Files Affected:
- (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+4-1)
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+59-46)
- (modified) mlir/test/Target/LLVMIR/Import/import-failure.ll (-9)
- (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+10)
``````````diff
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index eea0647895b01b..880e8201318e6c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -319,9 +319,12 @@ class ModuleImport {
/// Appends the converted result type and operands of `callInst` to the
/// `types` and `operands` arrays. For indirect calls, the method additionally
/// inserts the called function at the beginning of the `operands` array.
+ /// If `handleAsm` is set to false (the default), the method fails when the
+ /// called operand is an inline asm, which is not convertible to an MLIR value.
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
SmallVectorImpl<Type> &types,
- SmallVectorImpl<Value> &operands);
+ SmallVectorImpl<Value> &operands,
+ bool handleAsm = false);
/// Converts the parameter attributes attached to `func` and adds them to the
/// `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index b0d5e635248d3f..66cfb3f9dca110 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1473,18 +1473,19 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
return success();
}
-LogicalResult
-ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
- SmallVectorImpl<Type> &types,
- SmallVectorImpl<Value> &operands) {
+LogicalResult ModuleImport::convertCallTypeAndOperands(
+ llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<Value> &operands, bool handleAsm) {
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));
if (!callInst->getCalledFunction()) {
- FailureOr<Value> called = convertValue(callInst->getCalledOperand());
- if (failed(called))
- return failure();
- operands.push_back(*called);
+ if (!handleAsm || !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+ if (failed(called))
+ return failure();
+ operands.push_back(*called);
+ }
}
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,7 +1580,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
SmallVector<Type> types;
SmallVector<Value> operands;
- if (failed(convertCallTypeAndOperands(callInst, types, operands)))
+ if (failed(convertCallTypeAndOperands(callInst, types, operands, true)))
return failure();
auto funcTy =
@@ -1587,45 +1588,57 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (!funcTy)
return failure();
- CallOp callOp;
-
- if (llvm::Function *callee = callInst->getCalledFunction()) {
- callOp = builder.create<CallOp>(
- loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
- operands);
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ InlineAsmOp callOp = builder.create<InlineAsmOp>(
+ loc, funcTy.getReturnType(), operands,
+ builder.getStringAttr(asmI->getAsmString()),
+ builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
+ nullptr, nullptr);
+ if (!callInst->getType()->isVoidTy())
+ mapValue(inst, callOp.getResult(0));
+ else
+ mapNoResultOp(inst, callOp);
} else {
- callOp = builder.create<CallOp>(loc, funcTy, operands);
+ CallOp callOp;
+
+ if (llvm::Function *callee = callInst->getCalledFunction()) {
+ callOp = builder.create<CallOp>(
+ loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+ operands);
+ } else {
+ callOp = builder.create<CallOp>(loc, funcTy, operands);
+ }
+ callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+ callOp.setTailCallKind(
+ convertTailCallKindFromLLVM(callInst->getTailCallKind()));
+ setFastmathFlagsAttr(inst, callOp);
+
+ // Handle function attributes.
+ if (callInst->hasFnAttr(llvm::Attribute::Convergent))
+ callOp.setConvergent(true);
+ if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
+ callOp.setNoUnwind(true);
+ if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
+ callOp.setWillReturn(true);
+
+ llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
+ ModRefInfo othermem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+ ModRefInfo argMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+ ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+ auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
+ argMem, inaccessibleMem);
+ // Only set the attribute when it does not match the default value.
+ if (!memAttr.isReadWrite())
+ callOp.setMemoryEffectsAttr(memAttr);
+
+ if (!callInst->getType()->isVoidTy())
+ mapValue(inst, callOp.getResult());
+ else
+ mapNoResultOp(inst, callOp);
}
- callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
- callOp.setTailCallKind(
- convertTailCallKindFromLLVM(callInst->getTailCallKind()));
- setFastmathFlagsAttr(inst, callOp);
-
- // Handle function attributes.
- if (callInst->hasFnAttr(llvm::Attribute::Convergent))
- callOp.setConvergent(true);
- if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
- callOp.setNoUnwind(true);
- if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
- callOp.setWillReturn(true);
-
- llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
- ModRefInfo othermem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::Other));
- ModRefInfo argMem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
- ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
- auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
- inaccessibleMem);
- // Only set the attribute when it does not match the default value.
- if (!memAttr.isReadWrite())
- callOp.setMemoryEffectsAttr(memAttr);
-
- if (!callInst->getType()->isVoidTy())
- mapValue(inst, callOp.getResult());
- else
- mapNoResultOp(inst, callOp);
return success();
}
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 6bde174642d540..b616cb81e0a8a5 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -12,15 +12,6 @@ bb2:
; // -----
-; CHECK: <unknown>
-; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
-define i32 @unhandled_value(i32 %arg1) {
- %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
- ret i32 %1
-}
-
-; // -----
-
; CHECK: <unknown>
; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
; CHECK: <unknown>
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index fff48bbc486bc1..42b2cfff9611a3 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -535,6 +535,16 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
; // -----
+; CHECK-LABEL: @inlineasm
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define i32 @inlineasm(i32 %arg1) {
+ %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
+ ; CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bswap $0", "=r,r" (%[[ARG1]]) : (i32) -> (i32)
+ ret i32 %1
+}
+
+; // -----
+
; CHECK-LABEL: @gep_static_idx
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
define void @gep_static_idx(ptr %ptr) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/121624
More information about the Mlir-commits mailing list