[Mlir-commits] [mlir] 3aab320 - [MLIR][SPIRVToLLVM] Conversion for inverse sqrt and tanh

George Mitenkov llvmlistbot at llvm.org
Thu Jul 30 00:51:38 PDT 2020


Author: George Mitenkov
Date: 2020-07-30T10:50:48+03:00
New Revision: 3aab320557e7441bc2ce0b51fd6d82838fd0d484

URL: https://github.com/llvm/llvm-project/commit/3aab320557e7441bc2ce0b51fd6d82838fd0d484
DIFF: https://github.com/llvm/llvm-project/commit/3aab320557e7441bc2ce0b51fd6d82838fd0d484.diff

LOG: [MLIR][SPIRVToLLVM] Conversion for inverse sqrt and tanh

This is the second patch on converting GLSL ops to the LLVM dialect.
It introduces patterns to convert `spv.GLSL.InverseSqrt` and `spv.GLSL.Tanh`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84633

Added: 
    

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 803b05a032da..58d160d30a49 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -91,6 +91,21 @@ static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
 }
 
+/// Creates an `llvm.mlir.constant` of type `dstType` holding `value`: a splat
+/// over every element when `srcType` is a vector, else a scalar float.
+static Value createFPConstant(Location loc, Type srcType, Type dstType,
+                              PatternRewriter &rewriter, double value) {
+  Attribute valueAttr;
+  if (auto vecType = srcType.dyn_cast<VectorType>()) {
+    FloatAttr elementAttr = rewriter.getFloatAttr(
+        vecType.getElementType().cast<FloatType>(), value);
+    valueAttr = SplatElementsAttr::get(vecType, elementAttr);
+  } else {
+    valueAttr = rewriter.getFloatAttr(srcType.cast<FloatType>(), value);
+  }
+  return rewriter.create<LLVM::ConstantOp>(loc, dstType, valueAttr);
+}
+
 /// Utility function for bitfiled ops:
 ///   - `BitFieldInsert`
 ///   - `BitFieldSExtract`
@@ -590,6 +605,27 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
   }
 };
 
+/// Converts `spv.GLSL.InverseSqrt` to `1.0 / sqrt(x)` in the LLVM dialect.
+class InverseSqrtPattern
+    : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
+  LogicalResult
+  matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+    // Operate on the type-converted operand, not the original `op.operand()`.
+    Location loc = op.getLoc();
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, operands[0]);
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
+    return success();
+  }
+};
+
 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
 template <typename SPIRVop>
 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
@@ -821,6 +857,40 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
   }
 };
 
+/// Converts `spv.GLSL.Tanh` to
+///
+///   exp(2x) - 1
+///   -----------
+///   exp(2x) + 1
+///
+/// NOTE: gives NaN (inf/inf) once exp(2x) overflows: x > ~44 f32, ~5.5 f16.
+class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = tanhOp.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+    // Operate on the type-converted operand, not `tanhOp.operand()`.
+    Location loc = tanhOp.getLoc();
+    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
+    Value twoX = rewriter.create<LLVM::FMulOp>(loc, dstType, two, operands[0]);
+    Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, twoX);
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    Value numerator =
+        rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+    Value denominator =
+        rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
+                                              denominator);
+    return success();
+  }
+};
+
 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
 public:
   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
@@ -1052,7 +1122,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
-      DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, TanPattern,
+      DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
+      InverseSqrtPattern, TanPattern, TanhPattern,
 
       // Logical ops
       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
index 1907619445c9..ab501b8aabfd 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
@@ -103,3 +103,55 @@ func @tan(%arg0: f32) {
 	%0 = spv.GLSL.Tan %arg0 : f32
 	return
 }
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Tanh
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: f32) {
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : !llvm.float
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]]) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : !llvm.float
+  // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float
+  // CHECK: llvm.fdiv %[[T0]], %[[T1]] : !llvm.float
+  %0 = spv.GLSL.Tanh %arg0 : f32
+  return
+}
+
+// CHECK-LABEL: @tanh_vector
+func @tanh_vector(%arg0: vector<3xf16>) {
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(dense<2.000000e+00> : vector<3xf16>)
+  // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}}
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]])
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf16>)
+  // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]]
+  // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]]
+  // CHECK: llvm.fdiv %[[T0]], %[[T1]]
+  %0 = spv.GLSL.Tanh %arg0 : vector<3xf16>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.InverseSqrt
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inverse_sqrt
+func @inverse_sqrt(%arg0: f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}}) : (!llvm.float) -> !llvm.float
+  // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+  %0 = spv.GLSL.InverseSqrt %arg0 : f32
+  return
+}
+
+// CHECK-LABEL: @inverse_sqrt_vector
+func @inverse_sqrt_vector(%arg0: vector<3xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>)
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}})
+  // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]]
+  %0 = spv.GLSL.InverseSqrt %arg0 : vector<3xf32>
+  return
+}


        


More information about the Mlir-commits mailing list