[Mlir-commits] [mlir] c1f719d - [mlir] harden result type verification in llvm.call
Alex Zinenko
llvmlistbot at llvm.org
Wed Jul 28 09:16:05 PDT 2021
Author: Alex Zinenko
Date: 2021-07-28T18:15:56+02:00
New Revision: c1f719d1a749eaf4a4964292e3eed6ab2766f2c5
URL: https://github.com/llvm/llvm-project/commit/c1f719d1a749eaf4a4964292e3eed6ab2766f2c5
DIFF: https://github.com/llvm/llvm-project/commit/c1f719d1a749eaf4a4964292e3eed6ab2766f2c5.diff
LOG: [mlir] harden result type verification in llvm.call
The verifier of the llvm.call operation was not checking for mismatches between
the number of operation results and the number of results in the signature of
the callee. Furthermore, it was possible to construct an llvm.call operation
producing an SSA value of !llvm.void type, which should not exist. Add this
verification and treat an !llvm.void result type as the absence of call results.
Update the GPU-to-LLVM conversions, which mistakenly assumed that llvm.call
could produce values of !llvm.void type, so that these calls no longer
produce results.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D106937
Added:
Modified:
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 5269c16354bc2..3dfed5fec7a26 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -353,9 +353,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
- return builder.create<LLVM::CallOp>(
- loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
- builder.getSymbolRefAttr(function), arguments);
+ return builder.create<LLVM::CallOp>(loc, function, arguments);
}
// Returns whether all operands are of LLVM type.
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 118941539e0c2..933a158aff640 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -248,7 +248,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
}
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
- loc, TypeRange{getVoidType()},
+ loc, TypeRange(),
builder.getSymbolRefAttr(
StringRef(symbolName.data(), symbolName.size())),
ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
@@ -396,32 +396,31 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// Create call to `setBinaryShader` runtime function with the given pointer to
// SPIR-V binary and binary size.
builder.create<LLVM::CallOp>(
- loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
+ loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader),
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
Value entryPointName = createEntryPointNameConstant(
spirvAttributes.second.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
- builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kSetEntryPoint),
ValueRange{vulkanRuntime, entryPointName});
// Create number of local workgroup for each dimension.
builder.create<LLVM::CallOp>(
- loc, TypeRange{getVoidType()},
- builder.getSymbolRefAttr(kSetNumWorkGroups),
+ loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups),
ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
cInterfaceVulkanLaunchCallOp.getOperand(1),
cInterfaceVulkanLaunchCallOp.getOperand(2)});
// Create call to `runOnVulkan` runtime function.
- builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kRunOnVulkan),
ValueRange{vulkanRuntime});
// Create call to 'deinitVulkan' runtime function.
- builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kDeinitVulkan),
ValueRange{vulkanRuntime});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b9625b62a02d..bb71eff459ae5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -815,6 +815,19 @@ static LogicalResult verify(CallOp &op) {
<< ": " << op.getOperand(i + isIndirect).getType()
<< " != " << funcType.getParamType(i);
+ if (op.getNumResults() == 0 &&
+ !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
+ return op.emitOpError() << "expected function call to produce a value";
+
+ if (op.getNumResults() != 0 &&
+ funcType.getReturnType().isa<LLVM::LLVMVoidType>())
+ return op.emitOpError()
+ << "calling function with void result must not produce values";
+
+ if (op.getNumResults() > 1)
+ return op.emitOpError()
+ << "expected LLVM function call to produce 0 or 1 result";
+
if (op.getNumResults() &&
op.getResult(0).getType() != funcType.getReturnType())
return op.emitOpError()
@@ -874,19 +887,18 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType)
return parser.emitError(trailingTypeLoc, "expected function type");
+ if (funcType.getNumResults() > 1)
+ return parser.emitError(trailingTypeLoc,
+ "expected function with 0 or 1 result");
if (isDirect) {
// Make sure types match.
if (parser.resolveOperands(operands, funcType.getInputs(),
parser.getNameLoc(), result.operands))
return failure();
- result.addTypes(funcType.getResults());
+ if (funcType.getNumResults() != 0 &&
+ !funcType.getResult(0).isa<LLVM::LLVMVoidType>())
+ result.addTypes(funcType.getResults());
} else {
- // Construct the LLVM IR Dialect function type that the first operand
- // should match.
- if (funcType.getNumResults() > 1)
- return parser.emitError(trailingTypeLoc,
- "expected function with 0 or 1 result");
-
Builder &builder = parser.getBuilder();
Type llvmResultType;
if (funcType.getNumResults() == 0) {
@@ -921,7 +933,8 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands))
return failure();
- result.addTypes(llvmResultType);
+ if (!llvmResultType.isa<LLVM::LLVMVoidType>())
+ result.addTypes(llvmResultType);
}
return success();
diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
index 38796013b6e97..b99584ba126f0 100644
--- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
@@ -6,14 +6,14 @@
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> !llvm.void
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> !llvm.void
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.void
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> !llvm.void
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 19b80b86c43f7..38b7e4023c9b2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1089,3 +1089,33 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
llvm.return
}
+
+// -----
+
+llvm.func @caller() {
+ // expected-error @below {{expected function call to produce a value}}
+ llvm.call @callee() : () -> ()
+ llvm.return
+}
+
+llvm.func @callee() -> i32
+
+// -----
+
+llvm.func @caller() {
+ // expected-error @below {{calling function with void result must not produce values}}
+ %0 = llvm.call @callee() : () -> i32
+ llvm.return
+}
+
+llvm.func @callee() -> ()
+
+// -----
+
+llvm.func @caller() {
+ // expected-error @below {{expected function with 0 or 1 result}}
+ %0:2 = llvm.call @callee() : () -> (i32, f32)
+ llvm.return
+}
+
+llvm.func @callee() -> !llvm.struct<(i32, f32)>
More information about the Mlir-commits
mailing list