[Mlir-commits] [mlir] 3e746c6 - [mlir] Add support for ExpM1 to GLSL/OpenCL SPIRV Backends

Rob Suderman llvmlistbot at llvm.org
Mon Jan 24 15:42:17 PST 2022


Author: Rob Suderman
Date: 2022-01-24T15:38:34-08:00
New Revision: 3e746c6d9ef0758c1e06901a99a75b638d6a5655

URL: https://github.com/llvm/llvm-project/commit/3e746c6d9ef0758c1e06901a99a75b638d6a5655
DIFF: https://github.com/llvm/llvm-project/commit/3e746c6d9ef0758c1e06901a99a75b638d6a5655.diff

LOG: [mlir] Add support for ExpM1 to GLSL/OpenCL SPIRV Backends

Add a decomposition of exponential minus one (math.expm1, lowered as
exp(x) - 1), analogous to the existing log1p lowering, to the GLSL and
OpenCL SPIR-V backends, along with the necessary tests.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D118081

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index ec8402af03009..90588ed9bd5f0 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -30,6 +30,28 @@ using namespace mlir;
 // normal RewritePattern.
 
 namespace {
+/// Converts math.expm1 to SPIR-V ops. SPIR-V has no direct op for exp(x)-1, so
+/// lower to ExpOp(x) - 1.0; note this loses precision for x near 0.
+template <typename ExpOp>
+class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
+public:
+  using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Location loc = operation.getLoc();
+    auto type = this->getTypeConverter()->convertType(operation.getType());
+    if (!type)
+      return failure();
+    auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+    auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
+    return success();
+  }
+};
+
 /// Converts math.log1p to SPIR-V ops.
 ///
 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
@@ -44,11 +66,10 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() == 1);
     Location loc = operation.getLoc();
-    auto type =
-        this->getTypeConverter()->convertType(operation.getOperand().getType());
+    auto type = this->getTypeConverter()->convertType(operation.getType());
     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
     auto onePlus =
-        rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
+        rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
     return success();
   }
@@ -65,7 +86,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
   // GLSL patterns
   patterns
-      .add<Log1pOpPattern<spirv::GLSLLogOp>,
+      .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -81,7 +102,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
           typeConverter, patterns.getContext());
 
   // OpenCL patterns
-  patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
+  patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index f0e0b7e63fdce..c996e7056783a 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -8,26 +8,30 @@ func @float32_unary_scalar(%arg0: f32) {
   %0 = math.cos %arg0 : f32
   // CHECK: spv.GLSL.Exp %{{.*}}: f32
   %1 = math.exp %arg0 : f32
+  // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0
+  // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+  // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+  %2 = math.expm1 %arg0 : f32
   // CHECK: spv.GLSL.Log %{{.*}}: f32
-  %2 = math.log %arg0 : f32
+  %3 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spv.GLSL.Log %[[ADDONE]]
-  %3 = math.log1p %arg0 : f32
+  %4 = math.log1p %arg0 : f32
   // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
-  %4 = math.rsqrt %arg0 : f32
+  %5 = math.rsqrt %arg0 : f32
   // CHECK: spv.GLSL.Sqrt %{{.*}}: f32
-  %5 = math.sqrt %arg0 : f32
+  %6 = math.sqrt %arg0 : f32
   // CHECK: spv.GLSL.Tanh %{{.*}}: f32
-  %6 = math.tanh %arg0 : f32
+  %7 = math.tanh %arg0 : f32
   // CHECK: spv.GLSL.Sin %{{.*}}: f32
-  %7 = math.sin %arg0 : f32
+  %8 = math.sin %arg0 : f32
   // CHECK: spv.GLSL.FAbs %{{.*}}: f32
-  %8 = math.abs %arg0 : f32
+  %9 = math.abs %arg0 : f32
   // CHECK: spv.GLSL.Ceil %{{.*}}: f32
-  %9 = math.ceil %arg0 : f32
+  %10 = math.ceil %arg0 : f32
   // CHECK: spv.GLSL.Floor %{{.*}}: f32
-  %10 = math.floor %arg0 : f32
+  %11 = math.floor %arg0 : f32
   return
 }
 
@@ -37,20 +41,24 @@ func @float32_unary_vector(%arg0: vector<3xf32>) {
   %0 = math.cos %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.Exp %{{.*}}: vector<3xf32>
   %1 = math.exp %arg0 : vector<3xf32>
+  // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0
+  // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
+  // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+  %2 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32>
-  %2 = math.log %arg0 : vector<3xf32>
+  %3 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spv.GLSL.Log %[[ADDONE]]
-  %3 = math.log1p %arg0 : vector<3xf32>
+  %4 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32>
-  %4 = math.rsqrt %arg0 : vector<3xf32>
+  %5 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32>
-  %5 = math.sqrt %arg0 : vector<3xf32>
+  %6 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32>
-  %6 = math.tanh %arg0 : vector<3xf32>
+  %7 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32>
-  %7 = math.sin %arg0 : vector<3xf32>
+  %8 = math.sin %arg0 : vector<3xf32>
   return
 }
 

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 7580f1f733c49..d0959efc98ab2 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -8,28 +8,32 @@ func @float32_unary_scalar(%arg0: f32) {
   %0 = math.cos %arg0 : f32
   // CHECK: spv.OCL.exp %{{.*}}: f32
   %1 = math.exp %arg0 : f32
+  // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0
+  // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+  // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+  %2 = math.expm1 %arg0 : f32
   // CHECK: spv.OCL.log %{{.*}}: f32
-  %2 = math.log %arg0 : f32
+  %3 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spv.OCL.log %[[ADDONE]]
-  %3 = math.log1p %arg0 : f32
+  %4 = math.log1p %arg0 : f32
   // CHECK: spv.OCL.rsqrt %{{.*}}: f32
-  %4 = math.rsqrt %arg0 : f32
+  %5 = math.rsqrt %arg0 : f32
   // CHECK: spv.OCL.sqrt %{{.*}}: f32
-  %5 = math.sqrt %arg0 : f32
+  %6 = math.sqrt %arg0 : f32
   // CHECK: spv.OCL.tanh %{{.*}}: f32
-  %6 = math.tanh %arg0 : f32
+  %7 = math.tanh %arg0 : f32
   // CHECK: spv.OCL.sin %{{.*}}: f32
-  %7 = math.sin %arg0 : f32
+  %8 = math.sin %arg0 : f32
   // CHECK: spv.OCL.fabs %{{.*}}: f32
-  %8 = math.abs %arg0 : f32
+  %9 = math.abs %arg0 : f32
   // CHECK: spv.OCL.ceil %{{.*}}: f32
-  %9 = math.ceil %arg0 : f32
+  %10 = math.ceil %arg0 : f32
   // CHECK: spv.OCL.floor %{{.*}}: f32
-  %10 = math.floor %arg0 : f32
+  %11 = math.floor %arg0 : f32
   // CHECK: spv.OCL.erf %{{.*}}: f32
-  %11 = math.erf %arg0 : f32
+  %12 = math.erf %arg0 : f32
   return
 }
 
@@ -39,20 +43,24 @@ func @float32_unary_vector(%arg0: vector<3xf32>) {
   %0 = math.cos %arg0 : vector<3xf32>
   // CHECK: spv.OCL.exp %{{.*}}: vector<3xf32>
   %1 = math.exp %arg0 : vector<3xf32>
+  // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0
+  // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
+  // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+  %2 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spv.OCL.log %{{.*}}: vector<3xf32>
-  %2 = math.log %arg0 : vector<3xf32>
+  %3 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spv.OCL.log %[[ADDONE]]
-  %3 = math.log1p %arg0 : vector<3xf32>
+  %4 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32>
-  %4 = math.rsqrt %arg0 : vector<3xf32>
+  %5 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32>
-  %5 = math.sqrt %arg0 : vector<3xf32>
+  %6 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32>
-  %6 = math.tanh %arg0 : vector<3xf32>
+  %7 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spv.OCL.sin %{{.*}}: vector<3xf32>
-  %7 = math.sin %arg0 : vector<3xf32>
+  %8 = math.sin %arg0 : vector<3xf32>
   return
 }
 


        


More information about the Mlir-commits mailing list