[Mlir-commits] [mlir] 533ec92 - [mlir][spirv] Add pattern to lower math.copysign

Lei Zhang llvmlistbot at llvm.org
Fri Apr 1 09:06:53 PDT 2022


Author: Lei Zhang
Date: 2022-04-01T12:06:47-04:00
New Revision: 533ec929f67d0169f247cb73835a127053cc5933

URL: https://github.com/llvm/llvm-project/commit/533ec929f67d0169f247cb73835a127053cc5933
DIFF: https://github.com/llvm/llvm-project/commit/533ec929f67d0169f247cb73835a127053cc5933.diff

LOG: [mlir][spirv] Add pattern to lower math.copysign

This follows the bit-manipulation logic of musl's copysignf:
https://git.musl-libc.org/cgit/musl/tree/src/math/copysignf.c

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D122910

Added: 
    mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6d309d6760830..956118a5d9f77 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "math-to-spirv-pattern"
@@ -30,14 +31,88 @@ using namespace mlir;
 // normal RewritePattern.
 
 namespace {
+/// Converts math.copysign to SPIR-V ops.
+///
+/// SPIR-V core has no copysign instruction, so this follows musl's copysignf:
+/// bitcast both operands to same-width integers, keep the magnitude bits of
+/// the lhs and the sign bit of the rhs, OR them, and bitcast back to float.
+class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = getTypeConverter()->convertType(copySignOp.getType());
+    if (!type)
+      return failure();
+
+    // Take the element width from the converted type: the adaptor operands
+    // already carry that type, so the integer bitcasts must match its width
+    // even when the converter changes the bitwidth of the original type.
+    FloatType floatType;
+    if (auto scalarType = type.dyn_cast<FloatType>()) {
+      floatType = scalarType;
+    } else if (auto vectorType = type.dyn_cast<VectorType>()) {
+      // The masks below are splatted with a single CompositeConstruct, which
+      // only models 1-D vectors.
+      if (vectorType.getRank() != 1)
+        return failure();
+      floatType = vectorType.getElementType().dyn_cast<FloatType>();
+    }
+    if (!floatType)
+      return failure();
+
+    Location loc = copySignOp.getLoc();
+    int bitwidth = floatType.getWidth();
+    Type intType = rewriter.getIntegerType(bitwidth);
+    // Build the masks in 64 bits: shifting a 32-bit `1u` by 63 for f64 inputs
+    // would be undefined behavior.
+    uint64_t signBit = uint64_t(1) << (bitwidth - 1);
+
+    Value signMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType, rewriter.getIntegerAttr(intType, signBit));
+    Value valueMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType, rewriter.getIntegerAttr(intType, signBit - 1u));
+
+    if (auto vectorType = type.dyn_cast<VectorType>()) {
+      int count = vectorType.getNumElements();
+      intType = VectorType::get(count, intType);
+
+      // Splat the scalar masks to the operand's vector shape.
+      SmallVector<Value> signSplat(count, signMask);
+      signMask =
+          rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+
+      SmallVector<Value> valueSplat(count, valueMask);
+      valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
+                                                               valueSplat);
+    }
+
+    Value lhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+    Value rhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+
+    // Keep the magnitude bits of lhs and only the sign bit of rhs.
+    Value value = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{lhsCast, valueMask});
+    Value sign = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{rhsCast, signMask});
+
+    Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
+                                                       ValueRange{value, sign});
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
+    return success();
+  }
+};
+
 /// Converts math.expm1 to SPIR-V ops.
 ///
 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
 /// these operations.
 template <typename ExpOp>
-class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
-public:
-  using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
+struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
@@ -57,9 +118,8 @@ class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
 /// these operations.
 template <typename LogOp>
-class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
-public:
-  using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
+struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
@@ -83,6 +143,8 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
 namespace mlir {
 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
+  // Core patterns
+  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
 
   // GLSL patterns
   patterns

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
new file mode 100644
index 0000000000000..f2379f2cab5cd
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// Scalar f32: both operands are bitcast to i32 and combined with bit masks.
+func @copy_sign_scalar(%value: f32, %sign: f32) -> f32 {
+  %0 = math.copysign %value, %sign : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: func @copy_sign_scalar
+//  CHECK-SAME: (%[[VALUE:.+]]: f32, %[[SIGN:.+]]: f32)
+//       CHECK:   %[[SMASK:.+]] = spv.Constant -2147483648 : i32
+//       CHECK:   %[[VMASK:.+]] = spv.Constant 2147483647 : i32
+//       CHECK:   %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : f32 to i32
+//       CHECK:   %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : f32 to i32
+//       CHECK:   %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i32
+//       CHECK:   %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i32
+//       CHECK:   %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : i32
+//       CHECK:   %[[RESULT:.+]] = spv.Bitcast %[[OR]] : i32 to f32
+//       CHECK:   return %[[RESULT]]
+
+// -----
+// Vector f16 (Float16/Int16 enabled in the target env): 16-bit masks are splatted per lane.
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16, Int16], []>, {}> } {
+
+func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vector<3xf16> {
+  %0 = math.copysign %value, %sign : vector<3xf16>
+  return %0: vector<3xf16>
+}
+
+}
+
+// CHECK-LABEL: func @copy_sign_vector
+//  CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
+//       CHECK:   %[[SMASK:.+]] = spv.Constant -32768 : i16
+//       CHECK:   %[[VMASK:.+]] = spv.Constant 32767 : i16
+//       CHECK:   %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
+//       CHECK:   %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
+//       CHECK:   %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
+//       CHECK:   %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
+//       CHECK:   %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
+//       CHECK:   %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SVMASK]] : vector<3xi16>
+//       CHECK:   %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
+//       CHECK:   %[[RESULT:.+]] = spv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
+//       CHECK:   return %[[RESULT]]


        


More information about the Mlir-commits mailing list