[Mlir-commits] [mlir] a4dc613 - [MLIR][SPIRVToLLVM] Implementation of spv.func conversion, and return ops

Lei Zhang llvmlistbot at llvm.org
Tue Jun 23 08:34:19 PDT 2020


Author: George Mitenkov
Date: 2020-06-23T11:34:11-04:00
New Revision: a4dc61344f08f7581b8794c5819ffdf9c708ecfe

URL: https://github.com/llvm/llvm-project/commit/a4dc61344f08f7581b8794c5819ffdf9c708ecfe
DIFF: https://github.com/llvm/llvm-project/commit/a4dc61344f08f7581b8794c5819ffdf9c708ecfe.diff

LOG: [MLIR][SPIRVToLLVM] Implementation of spv.func conversion, and return ops

This patch provides an implementation for `spv.func` conversion. The pattern
is populated in a separate method added to the pass. At the moment, the type
signature conversion only includes the supported types. The conversion pattern
also matches SPIR-V function control attributes to LLVM function attributes.
These are modelled as `passthrough` attributes in the LLVM dialect. The
following mappings are used:
- None: no attributes passed
- Inline: `alwaysinline` seems to be the right equivalent (`inlinehint` is
  semantically weaker in my opinion)
- DontInline: `noinline`
- Pure and Const: I think those can be modelled as `readonly` and `readnone`
  attributes respectively.

Also, two patterns are added for return op conversion (`spv.Return` for a void
return and `spv.ReturnValue` for a single-value return).

Differential Revision: https://reviews.llvm.org/D81931

Added: 
    mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index e82efac3abe5..178d27dff41b 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -37,6 +37,12 @@ void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
                                            OwningRewritePatternList &patterns);
 
+/// Populates `patterns` with the SPIR-V to LLVM function conversion pattern
+/// (currently only `spv.func` to `llvm.func`), using `typeConverter`.
+void populateSPIRVToLLVMFunctionConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 85f6a113cff3..e32fdc5fe60a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -21,6 +21,9 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "spirv-to-llvm-pattern"
 
 using namespace mlir;
 
@@ -150,6 +153,32 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
   }
 };
 
+/// Lowers a void `spv.Return` to an `llvm.return` with no operands.
+class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ReturnOp op, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    ArrayRef<Type> noResults;
+    ArrayRef<Value> noOperands;
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, noResults, noOperands);
+    return success();
+  }
+};
+
+/// Lowers `spv.ReturnValue` to an `llvm.return` that forwards the (already
+/// type-converted) operands of the original op.
+class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ReturnValueOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ArrayRef<Type> noResults;
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, noResults, operands);
+    return success();
+  }
+};
+
 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
 /// puts a restriction on `Shift` and `Base` to have the same bit width,
 /// `Shift` is zero or sign extended to match this specification. Cases when
@@ -191,6 +220,64 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// FuncOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `spv.func` to `llvm.func`: the signature is converted with the
+/// LLVM type converter, the body region is moved into the new function, and
+/// the SPIR-V function control is mapped to an LLVM `passthrough` attribute.
+class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Convert function signature. At the moment LLVMType converter is enough
+    // for currently supported types.
+    auto funcType = funcOp.getType();
+    TypeConverter::SignatureConversion signatureConverter(
+        funcType.getNumInputs());
+    auto llvmType = this->typeConverter.convertFunctionSignature(
+        funcType, /*isVariadic=*/false, signatureConverter);
+    // Fail the pattern before touching the IR if any argument/result type is
+    // not convertible; otherwise an `llvm.func` with a null type is created.
+    if (!llvmType)
+      return failure();
+
+    // Create a new `LLVMFuncOp`
+    Location loc = funcOp.getLoc();
+    StringRef name = funcOp.getName();
+    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
+
+    // Convert SPIR-V Function Control to equivalent LLVM function attribute.
+    // NOTE(review): per the SPIR-V spec, Function Control is a bit mask, so
+    // combined values (e.g. `Inline|Pure`) reach `default` and currently
+    // produce no attribute.
+    MLIRContext *context = funcOp.getContext();
+    switch (funcOp.function_control()) {
+#define DISPATCH(functionControl, llvmAttr)                                    \
+  case functionControl:                                                        \
+    newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));     \
+    break;
+
+    DISPATCH(spirv::FunctionControl::Inline,
+             StringAttr::get("alwaysinline", context));
+    DISPATCH(spirv::FunctionControl::DontInline,
+             StringAttr::get("noinline", context));
+    DISPATCH(spirv::FunctionControl::Pure,
+             StringAttr::get("readonly", context));
+    DISPATCH(spirv::FunctionControl::Const,
+             StringAttr::get("readnone", context));
+
+#undef DISPATCH
+
+    // Default: if `spirv::FunctionControl::None`, then no attributes are
+    // needed.
+    default:
+      break;
+    }
+
+    // Move the body into the new function and rewrite its entry block
+    // arguments to the converted types.
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -263,6 +350,14 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
-      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
-                                                            typeConverter);
+      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
+
+      // Return ops
+      ReturnPattern, ReturnValuePattern>(context, typeConverter);
+}
+
+/// Adds `FuncConversionPattern` (`spv.func` -> `llvm.func`) to `patterns`;
+/// `typeConverter` is used to convert the function signature.
+void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<FuncConversionPattern>(context, typeConverter);
 }

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
index 8f300541a71d..81a3a711433d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -35,6 +35,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
 
   OwningRewritePatternList patterns;
   populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
+  populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
 
   // Currently pulls in Std to LLVM conversion patterns
   // that help with testing. This allows to convert

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
new file mode 100644
index 000000000000..5d2c491e785e
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+func @return() {
+  // CHECK: llvm.return
+  spv.Return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+func @return_value(%arg: i32) {
+  // CHECK: llvm.return %{{.*}} : !llvm.i32
+  spv.ReturnValue %arg : i32
+}
+
+//===----------------------------------------------------------------------===//
+// spv.func
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: llvm.func @none()
+spv.func @none() -> () "None" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @inline() attributes {passthrough = ["alwaysinline"]}
+spv.func @inline() -> () "Inline" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @dont_inline() attributes {passthrough = ["noinline"]}
+spv.func @dont_inline() -> () "DontInline" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @pure() attributes {passthrough = ["readonly"]}
+spv.func @pure() -> () "Pure" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @const() attributes {passthrough = ["readnone"]}
+spv.func @const() -> () "Const" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @scalar_types(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.double, %arg3: !llvm.float)
+spv.func @scalar_types(%arg0: i32, %arg1: i1, %arg2: f64, %arg3: f32) -> () "None" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @vector_types(%arg0: !llvm<"<2 x i64>">, %arg1: !llvm<"<2 x i64>">) -> !llvm<"<2 x i64>">
+spv.func @vector_types(%arg0: vector<2xi64>, %arg1: vector<2xi64>) -> vector<2xi64> "None" {
+  %0 = spv.IAdd %arg0, %arg1 : vector<2xi64>
+  spv.ReturnValue %0 : vector<2xi64>
+}
+
+
+


        


More information about the Mlir-commits mailing list