[Mlir-commits] [mlir] 5e0c3b4 - [mlir][LLVMIR] Clean up the definitions of ReturnOp/CallOp

Jeff Niu llvmlistbot at llvm.org
Wed Aug 10 21:35:17 PDT 2022


Author: Jeff Niu
Date: 2022-08-11T00:35:02-04:00
New Revision: 5e0c3b4309df8ad74ff096593c3a1dd28f8dd571

URL: https://github.com/llvm/llvm-project/commit/5e0c3b4309df8ad74ff096593c3a1dd28f8dd571
DIFF: https://github.com/llvm/llvm-project/commit/5e0c3b4309df8ad74ff096593c3a1dd28f8dd571.diff

LOG: [mlir][LLVMIR] Clean up the definitions of ReturnOp/CallOp

Added:
    (none)

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 406f838f14606..9de4334a9d70b 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -14,6 +14,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
+class CallOpInterface;
 
 namespace LLVM {
 namespace detail {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 7a9167c5151f2..ac86e8461d277 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -29,10 +29,9 @@ namespace LLVM {
 class LLVMFuncOp;
 
 /// Helper functions to lookup or create the declaration for commonly used
-/// external C function calls. Such ops can then be invoked by creating a CallOp
-/// with the proper arguments via `createLLVMCall`.
-/// The list of functions provided here must be implemented separately (e.g. as
-/// part of a support runtime library or as part of the libc).
+/// external C function calls. The list of functions provided here must be
+/// implemented separately (e.g. as part of a support runtime library or as part
+/// of the libc).
 LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
@@ -58,12 +57,6 @@ LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
                                   Type resultType = {});
 
-/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
-Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
-                                       LLVM::LLVMFuncOp fn,
-                                       ValueRange args = {},
-                                       ArrayRef<Type> resultTypes = {});
-
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 203f35f36c197..84fc59243dc4f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -645,14 +645,16 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
 def LLVM_CallOp : LLVM_Op<"call",
                           [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                            DeclareOpInterfaceMethods<CallOpInterface>,
                            DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
-
-
     In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
     implements this behavior by providing a variadic `call` operation for 0- and
     1-result functions. Even though MLIR supports multi-result functions, LLVM
@@ -678,29 +680,20 @@ def LLVM_CallOp : LLVM_Op<"call",
     llvm.call %1(%0) : (f32) -> ()
     ```
   }];
+
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
-                   Variadic<LLVM_Type>,
-                   DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
-  let results = (outs Variadic<LLVM_Type>);
+                       Variadic<LLVM_Type>,
+                       DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+  let results = (outs Optional<LLVM_Type>:$result);
+
   let builders = [
-    OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
-      Type resultType = func.getFunctionType().getReturnType();
-      if (!resultType.isa<LLVM::LLVMVoidType>())
-        $_state.addTypes(resultType);
-      $_state.addAttribute("callee", SymbolRefAttr::get(func));
-      $_state.addAttributes(attributes);
-      $_state.addOperands(operands);
-    }]>,
+    OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
-                   CArg<"ValueRange", "{}">:$operands), [{
-      build($_builder, $_state, results, SymbolRefAttr::get(callee), operands);
-    }]>,
+                   CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
-                   CArg<"ValueRange", "{}">:$operands), [{
-      build($_builder, $_state, results,
-            StringAttr::get($_builder.getContext(), callee), operands);
-    }]>];
+                   CArg<"ValueRange", "{}">:$args)>
+  ];
+
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -878,25 +871,38 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
             falseOperands);
   }]>, LLVM_TerminatorPassthroughOpBuilder];
 }
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
 def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
-  let arguments = (ins Variadic<LLVM_Type>:$args);
+  let arguments = (ins Optional<LLVM_Type>:$arg);
+  let assemblyFormat = "attr-dict ($arg^ `:` type($arg))?";
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$args), [{
+      build($_builder, $_state, TypeRange(), args);
+    }]>
+  ];
+
+  let hasVerifier = 1;
+
   string llvmBuilder = [{
     if ($_numOperands != 0)
-      builder.CreateRet($args[0]);
+      builder.CreateRet($arg[0]);
     else
       builder.CreateRetVoid();
   }];
-
-  let assemblyFormat = "attr-dict ($args^ `:` type($args))?";
-  let hasVerifier = 1;
 }
-def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
+
+def LLVM_ResumeOp : LLVM_TerminatorOp<"resume"> {
   let arguments = (ins LLVM_Type:$value);
   string llvmBuilder = [{ builder.CreateResume($value); }];
   let assemblyFormat = "$value attr-dict `:` type($value)";
   let hasVerifier = 1;
 }
-def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
+def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
   string llvmBuilder = [{ builder.CreateUnreachable(); }];
   let assemblyFormat = "attr-dict";
 }

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index e7930653aee1a..aa6598c26853d 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -350,8 +350,8 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
     // requires the size parameter be an integral multiple of the alignment
     // parameter.
     auto makeConstant = [&](uint64_t c) {
-      return rewriter.create<LLVM::ConstantOp>(
-          op->getLoc(), rewriter.getI64Type(), c);
+      return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
+                                               rewriter.getI64Type(), c);
     };
     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
     coroSize =
@@ -365,13 +365,12 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
-        loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
-        ValueRange{coroAlign, coroSize});
+        loc, allocFuncOp, ValueRange{coroAlign, coroSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
-        op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
+        op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
 
     return success();
   }
@@ -400,8 +399,7 @@ class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
     // Free the memory.
     auto freeFuncOp =
         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
-                                              SymbolRefAttr::get(freeFuncOp),
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
                                               ValueRange(coroMem.getResult()));
 
     return success();

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index e8feb90198c1a..1d7e28291fc23 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -164,7 +164,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
 
   if (resultIsNowArg) {
-    rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
+    rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
                                    wrapperFuncOp.getArgument(0));
     rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
   } else {
@@ -265,7 +265,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
 
   if (resultIsNowArg) {
     Value result = builder.create<LLVM::LoadOp>(loc, args.front());
-    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
+    builder.create<LLVM::ReturnOp>(loc, result);
   } else {
     builder.create<LLVM::ReturnOp>(loc, call.getResults());
   }
@@ -617,12 +617,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
     }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
-    if (numArguments == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
-                                                  op->getAttrs());
-      return success();
-    }
-    if (numArguments == 1) {
+    if (numArguments <= 1) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
           op, TypeRange(), updatedOperands, op->getAttrs());
       return success();
@@ -630,13 +625,13 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
-    auto packedType = getTypeConverter()->packFunctionResults(
-        llvm::to_vector<4>(op.getOperandTypes()));
+    auto packedType =
+        getTypeConverter()->packFunctionResults(op.getOperandTypes());
 
     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
-    for (unsigned i = 0; i < numArguments; ++i) {
-      packed = rewriter.create<LLVM::InsertValueOp>(loc, packed,
-                                                    updatedOperands[i], i);
+    for (auto &it : llvm::enumerate(updatedOperands)) {
+      packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, it.value(),
+                                                    it.index());
     }
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                                 op->getAttrs());

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a62191d586787..459180b9d9e4c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -220,7 +220,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   /// Start the printf hostcall
   Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
-  Value printfDesc = printfBeginCall.getResult(0);
+  Value printfDesc = printfBeginCall.getResult();
 
   // Create a global constant for the format string
   unsigned stringNumber = 0;
@@ -259,7 +259,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
       loc, ocklAppendStringN,
       ValueRange{printfDesc, stringStart, stringLen,
                  adaptor.args().empty() ? oneI32 : zeroI32});
-  printfDesc = appendFormatCall.getResult(0);
+  printfDesc = appendFormatCall.getResult();
 
   // __ockl_printf_append_args takes 7 values per append call
   constexpr size_t argsPerAppend = 7;
@@ -293,7 +293,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
     auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
     arguments.push_back(isLast);
     auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
-    printfDesc = call.getResult(0);
+    printfDesc = call.getResult();
   }
   rewriter.eraseOp(gpuPrintfOp);
   return success();

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 2704e1408f769..8756603acf3e6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -482,7 +482,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   Type elementPtrType = this->getElementPtrType(memRefType);
   auto stream = adaptor.asyncDependencies().front();
   Value allocatedPtr =
-      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
+      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult();
   allocatedPtr =
       rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
 
@@ -539,7 +539,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
       continue;
     auto idx = operand.getOperandNumber();
     auto stream = adaptor.getOperands()[idx];
-    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
     eventRecordCallBuilder.create(loc, rewriter, {event, stream});
     newOperands[idx] = event;
     streams.insert(stream);
@@ -612,8 +612,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
       // into the stream just after the last use of the original token operand.
       auto *defOp = std::get<0>(pair).getDefiningOp();
       rewriter.setInsertionPointAfter(defOp);
-      auto event =
-          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
       eventRecordCallBuilder.create(loc, rewriter, {event, operand});
       events.push_back(event);
     } else {
@@ -623,7 +622,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
     }
   }
   rewriter.restoreInsertionPoint(insertionPoint);
-  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
   for (auto event : events)
     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
   for (auto event : events)
@@ -784,11 +783,11 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       launchOp.getKernelModuleName().getValue(),
       launchOp.getKernelName().getValue(), loc, rewriter);
   auto function = moduleGetFunctionCallBuilder.create(
-      loc, rewriter, {module.getResult(0), kernelName});
+      loc, rewriter, {module.getResult(), kernelName});
   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
   Value stream =
       adaptor.asyncDependencies().empty()
-          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
           : adaptor.asyncDependencies().front();
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
@@ -798,7 +797,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
                                       : zero;
   launchKernelCallBuilder.create(
       loc, rewriter,
-      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
+      {function.getResult(), adaptor.gridSizeX(), adaptor.gridSizeY(),
        adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
        adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
        /*extra=*/nullpointer});
@@ -814,7 +813,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
     streamDestroyCallBuilder.create(loc, rewriter, stream);
     rewriter.eraseOp(launchOp);
   }
-  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
+  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
 
   return success();
 }

diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 2cf16934acdc0..0168e6e016449 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -60,17 +60,17 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
-    auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
+    auto callOp =
+        rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
 
     if (resultType == adaptor.getOperands().front().getType()) {
-      rewriter.replaceOp(op, {callOp.getResult(0)});
+      rewriter.replaceOp(op, {callOp.getResult()});
       return success();
     }
 
     Value truncated = rewriter.create<LLVM::FPTruncOp>(
         op->getLoc(), adaptor.getOperands().front().getType(),
-        callOp.getResult(0));
+        callOp.getResult());
     rewriter.replaceOp(op, {truncated});
     return success();
   }

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 09e3ec41a43f6..3242add0799d1 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -374,7 +374,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
       loc, TypeRange{getPointerType()}, kInitVulkan);
   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
   // need to pass that pointer to each Vulkan runtime call.
-  auto vulkanRuntime = initVulkanCall.getResult(0);
+  auto vulkanRuntime = initVulkanCall.getResult();
 
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 15839695978b7..73bb64d278237 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -273,7 +273,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     Value memory =
         toDynamic
             ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
-                  .getResult(0)
+                  .getResult()
             : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
                                              /*alignment=*/0);
     Value source = desc.memRefDescPtr(builder, loc);

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ba92ae9ebfb59..8b3348f039b28 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -71,10 +71,9 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
-    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
-                                  getVoidPtrType());
-    Value allocatedPtr =
-        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
+    auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+    Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+                                                          results.getResult());
 
     Value alignedPtr = allocatedPtr;
     if (alignment) {
@@ -168,11 +167,10 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
 
     Type elementPtrType = this->getElementPtrType(memRefType);
     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
-    auto results =
-        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
-                       getVoidPtrType());
-    Value allocatedPtr =
-        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
+    auto results = rewriter.create<LLVM::CallOp>(
+        loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+    Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+                                                          results.getResult());
 
     return std::make_tuple(allocatedPtr, allocatedPtr);
   }
@@ -330,8 +328,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
     Value casted = rewriter.create<LLVM::BitcastOp>(
         op.getLoc(), getVoidPtrType(),
         memref.allocatedPtr(rewriter, op.getLoc()));
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, casted);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 65e26aaed5672..409c513f69ddb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -149,13 +149,3 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
-
-Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
-                                                   LLVM::LLVMFuncOp fn,
-                                                   ValueRange paramTypes,
-                                                   ArrayRef<Type> resultTypes) {
-  return b
-      .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
-                            paramTypes)
-      ->getResults();
-}

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9706e4b2e622e..7c4f1a4c827f4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1145,9 +1145,28 @@ ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 //===----------------------------------------------------------------------===//
-// Verifying/Printing/parsing for LLVM::CallOp.
+// CallOp
 //===----------------------------------------------------------------------===//
 
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+                   StringRef callee, ValueRange args) {
+  build(builder, state, results, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+                   StringAttr callee, ValueRange args) {
+  build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
+                   ValueRange args) {
+  SmallVector<Type> results;
+  Type resultType = func.getFunctionType().getReturnType();
+  if (!resultType.isa<LLVM::LLVMVoidType>())
+    results.push_back(resultType);
+  build(builder, state, results, SymbolRefAttr::get(func), args, nullptr);
+}
+
 CallInterfaceCallable CallOp::getCallableForCallee() {
   // Direct call.
   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
@@ -1235,8 +1254,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return emitOpError()
            << "expected LLVM function call to produce 0 or 1 result";
 
-  if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
-    return emitOpError() << "result type mismatch: " << getResult(0).getType()
+  if (getNumResults() && getResult().getType() != funcType.getReturnType())
+    return emitOpError() << "result type mismatch: " << getResult().getType()
                          << " != " << funcType.getReturnType();
 
   return success();
@@ -1608,34 +1627,33 @@ void InsertValueOp::build(OpBuilder &builder, OperationState &state,
 }
 
 //===----------------------------------------------------------------------===//
-// Printing, parsing and verification for LLVM::ReturnOp.
+// ReturnOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReturnOp::verify() {
-  if (getNumOperands() > 1)
-    return emitOpError("expected at most 1 operand");
-
-  if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
-    Type expectedType = parent.getFunctionType().getReturnType();
-    if (expectedType.isa<LLVMVoidType>()) {
-      if (getNumOperands() == 0)
-        return success();
-      InFlightDiagnostic diag = emitOpError("expected no operands");
-      diag.attachNote(parent->getLoc()) << "when returning from function";
-      return diag;
-    }
-    if (getNumOperands() == 0) {
-      if (expectedType.isa<LLVMVoidType>())
-        return success();
-      InFlightDiagnostic diag = emitOpError("expected 1 operand");
-      diag.attachNote(parent->getLoc()) << "when returning from function";
-      return diag;
-    }
-    if (expectedType != getOperand(0).getType()) {
-      InFlightDiagnostic diag = emitOpError("mismatching result types");
-      diag.attachNote(parent->getLoc()) << "when returning from function";
-      return diag;
-    }
+  auto parent = (*this)->getParentOfType<LLVMFuncOp>();
+  if (!parent)
+    return success();
+
+  Type expectedType = parent.getFunctionType().getReturnType();
+  if (expectedType.isa<LLVMVoidType>()) {
+    if (!getArg())
+      return success();
+    InFlightDiagnostic diag = emitOpError("expected no operands");
+    diag.attachNote(parent->getLoc()) << "when returning from function";
+    return diag;
+  }
+  if (!getArg()) {
+    if (expectedType.isa<LLVMVoidType>())
+      return success();
+    InFlightDiagnostic diag = emitOpError("expected 1 operand");
+    diag.attachNote(parent->getLoc()) << "when returning from function";
+    return diag;
+  }
+  if (expectedType != getArg().getType()) {
+    InFlightDiagnostic diag = emitOpError("mismatching result types");
+    diag.attachNote(parent->getLoc()) << "when returning from function";
+    return diag;
   }
   return success();
 }


        


More information about the Mlir-commits mailing list