[llvm-branch-commits] [mlir] 42980a7 - [mlir][spirv] Convert functions returning one value
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 23 10:41:04 PST 2020
Author: Lei Zhang
Date: 2020-12-23T13:27:31-05:00
New Revision: 42980a789d2212f774dbb12c2555452d328089a6
URL: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6
DIFF: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6.diff
LOG: [mlir][spirv] Convert functions returning one value
Reviewed By: hanchung, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D93468
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index d15623568212..470f4143f2c5 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -924,10 +924,19 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
LogicalResult
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (returnOp.getNumOperands()) {
- return failure();
- }
- rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
- return success();
+ // A SPIR-V function returns at most one value: no operands lowers to
+ // spv.Return, a single operand lowers to spv.ReturnValue, and anything
+ // more makes this pattern fail to match.
+ switch (returnOp.getNumOperands()) {
+ case 0:
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+ return success();
+ case 1:
+ rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
+ return success();
+ default:
+ return failure();
+ }
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index b310d5df7b26..9393f3df6425 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -473,23 +473,34 @@ LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
- // TODO: support converting functions with one result.
- if (fnType.getNumResults())
+ // SPIR-V functions can return at most one value.
+ if (fnType.getNumResults() > 1)
return failure();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- for (auto argType : enumerate(funcOp.getType().getInputs())) {
+ for (auto argType : enumerate(fnType.getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
if (!convertedType)
return failure();
signatureConverter.addInputs(argType.index(), convertedType);
}
+ Type resultType;
+ if (fnType.getNumResults() == 1) {
+ resultType = typeConverter.convertType(fnType.getResult(0));
+ // Like the argument types above, an unconvertible result type must fail
+ // the pattern; otherwise a null resultType would silently yield a
+ // spv.func with no results.
+ if (!resultType)
+ return failure();
+ }
+
// Create the converted spv.func op.
auto newFuncOp = rewriter.create<spirv::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
- llvm::None));
+ resultType ? TypeRange(resultType)
+ : TypeRange()));
// Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp.getAttrs()) {
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 10e43ef4acd7..850e22465d44 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -954,3 +954,29 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
}
} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.return
+//===----------------------------------------------------------------------===//
+
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: spv.func @return_one_val
+// CHECK-SAME: (%[[ARG:.+]]: f32) -> f32
+func @return_one_val(%arg0: f32) -> f32 {
+ // CHECK: spv.ReturnValue %[[ARG]] : f32
+ return %arg0: f32
+}
+
+// Check that multiple-return functions are not converted.
+// CHECK-LABEL: func @return_multi_val
+func @return_multi_val(%arg0: f32) -> (f32, f32) {
+ // CHECK: return %{{.*}}, %{{.*}} : f32, f32
+ return %arg0, %arg0: f32, f32
+}
+
+}
More information about the llvm-branch-commits
mailing list