[Mlir-commits] [mlir] 58cdb8b - [mlir][StandardToSPIRV] Add support for lowering unary ops

Lei Zhang llvmlistbot at llvm.org
Tue Mar 24 06:17:13 PDT 2020


Author: Hanhan Wang
Date: 2020-03-24T09:16:10-04:00
New Revision: 58cdb8bff067a521dd68d6b699b13da74188a68b

URL: https://github.com/llvm/llvm-project/commit/58cdb8bff067a521dd68d6b699b13da74188a68b
DIFF: https://github.com/llvm/llvm-project/commit/58cdb8bff067a521dd68d6b699b13da74188a68b.diff

LOG: [mlir][StandardToSPIRV] Add support for lowering unary ops

Differential Revision: https://reviews.llvm.org/D76661

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 69ef69d1de65..ea8812cebdc4 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -107,16 +107,16 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
 
 namespace {
 
-/// Converts binary standard operations to SPIR-V operations.
+/// Converts unary and binary standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
-class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
+class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
 public:
   using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
 
   LogicalResult
   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() == 2);
+    assert(operands.size() <= 2);
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
@@ -572,21 +572,31 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
   patterns.insert<
-      BinaryOpPattern<AddFOp, spirv::FAddOp>,
-      BinaryOpPattern<AddIOp, spirv::IAddOp>,
-      BinaryOpPattern<DivFOp, spirv::FDivOp>,
-      BinaryOpPattern<MulFOp, spirv::FMulOp>,
-      BinaryOpPattern<MulIOp, spirv::IMulOp>,
-      BinaryOpPattern<RemFOp, spirv::FRemOp>,
-      BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
-      BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
-      BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
-      BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
-      BinaryOpPattern<SubFOp, spirv::FSubOp>,
-      BinaryOpPattern<SubIOp, spirv::ISubOp>,
-      BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
-      BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
-      BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
+      UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
+      UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
+      UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
+      UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
+      UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
+      UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
+      UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
+      UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
+      UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
+      UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
+      UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
+      UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
+      UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
+      UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
+      UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
+      UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
+      UnaryAndBinaryOpPattern<SignedShiftRightOp,
+                              spirv::ShiftRightArithmeticOp>,
+      UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
+      UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
+      UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
+      UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
+      UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
+      UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
+      UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index cb5873a1baf0..91219acc0bd5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -31,9 +31,33 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
   return
 }
 
-// Check float operation conversions.
-// CHECK-LABEL: @float32_scalar
-func @float32_scalar(%lhs: f32, %rhs: f32) {
+// Check float unary operation conversions.
+// CHECK-LABEL: @float32_unary_scalar
+func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spv.GLSL.FAbs %{{.*}}: f32
+  %0 = absf %arg0 : f32
+  // CHECK: spv.GLSL.Ceil %{{.*}}: f32
+  %1 = ceilf %arg0 : f32
+  // CHECK: spv.GLSL.Cos %{{.*}}: f32
+  %2 = cos %arg0 : f32
+  // CHECK: spv.GLSL.Exp %{{.*}}: f32
+  %3 = exp %arg0 : f32
+  // CHECK: spv.GLSL.Log %{{.*}}: f32
+  %4 = log %arg0 : f32
+  // CHECK: spv.FNegate %{{.*}}: f32
+  %5 = negf %arg0 : f32
+  // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
+  %6 = rsqrt %arg0 : f32
+  // CHECK: spv.GLSL.Sqrt %{{.*}}: f32
+  %7 = sqrt %arg0 : f32
+  // CHECK: spv.GLSL.Tanh %{{.*}}: f32
+  %8 = tanh %arg0 : f32
+  return
+}
+
+// Check float binary operation conversions.
+// CHECK-LABEL: @float32_binary_scalar
+func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
   %0 = addf %lhs, %rhs: f32
   // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32


        


More information about the Mlir-commits mailing list