[llvm-branch-commits] [mlir] 42980a7 - [mlir][spirv] Convert functions returning one value

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 23 10:41:04 PST 2020


Author: Lei Zhang
Date: 2020-12-23T13:27:31-05:00
New Revision: 42980a789d2212f774dbb12c2555452d328089a6

URL: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6
DIFF: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6.diff

LOG: [mlir][spirv] Convert functions returning one value

Reviewed By: hanchung, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D93468

Added: 
    (none)

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    (none)


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index d15623568212..470f4143f2c5 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -924,10 +924,14 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
 LogicalResult
 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
-  if (returnOp.getNumOperands()) {
+  if (returnOp.getNumOperands() > 1)
     return failure();
+
+  if (returnOp.getNumOperands() == 1) {
+    rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
+  } else {
+    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
   }
-  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index b310d5df7b26..9393f3df6425 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -473,23 +473,27 @@ LogicalResult
 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
   auto fnType = funcOp.getType();
-  // TODO: support converting functions with one result.
-  if (fnType.getNumResults())
+  if (fnType.getNumResults() > 1)
     return failure();
 
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  for (auto argType : enumerate(funcOp.getType().getInputs())) {
+  for (auto argType : enumerate(fnType.getInputs())) {
     auto convertedType = typeConverter.convertType(argType.value());
     if (!convertedType)
       return failure();
     signatureConverter.addInputs(argType.index(), convertedType);
   }
 
+  Type resultType;
+  if (fnType.getNumResults() == 1)
+    resultType = typeConverter.convertType(fnType.getResult(0));
+
   // Create the converted spv.func op.
   auto newFuncOp = rewriter.create<spirv::FuncOp>(
       funcOp.getLoc(), funcOp.getName(),
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
-                               llvm::None));
+                               resultType ? TypeRange(resultType)
+                                          : TypeRange()));
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp.getAttrs()) {

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 10e43ef4acd7..850e22465d44 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -954,3 +954,29 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
 }
 
 } // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.return
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: spv.func @return_one_val
+//  CHECK-SAME: (%[[ARG:.+]]: f32)
+func @return_one_val(%arg0: f32) -> f32 {
+  // CHECK: spv.ReturnValue %[[ARG]] : f32
+  return %arg0: f32
+}
+
+// Check that multiple-return functions are not converted.
+// CHECK-LABEL: func @return_multi_val
+func @return_multi_val(%arg0: f32) -> (f32, f32) {
+  // CHECK: return
+  return %arg0, %arg0: f32, f32
+}
+
+}


        


More information about the llvm-branch-commits mailing list