[Mlir-commits] [mlir] 8dae0b6 - [mlir][spirv] arith::RemSIOp OpenCL lowering

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 25 01:44:33 PST 2021


Author: Butygin
Date: 2021-11-25T12:44:06+03:00
New Revision: 8dae0b6b6c9a0256908820b9ef4bc04f0119666b

URL: https://github.com/llvm/llvm-project/commit/8dae0b6b6c9a0256908820b9ef4bc04f0119666b
DIFF: https://github.com/llvm/llvm-project/commit/8dae0b6b6c9a0256908820b9ef4bc04f0119666b.diff

LOG: [mlir][spirv] arith::RemSIOp OpenCL lowering

Differential Revision: https://reviews.llvm.org/D114524

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 6fd69637df1d8..99e15a2f25a11 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -45,11 +45,20 @@ struct ConstantScalarOpPattern final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.remsi to SPIR-V ops.
+/// Converts arith.remsi to GLSL SPIR-V ops.
 ///
 /// This cannot be merged into the template unary/binary pattern due to Vulkan
 /// restrictions over spv.SRem and spv.SMod.
-struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
+struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
+  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
 
   LogicalResult
@@ -396,7 +405,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// RemSIOpPattern
+// RemSIOpGLSLPattern
 //===----------------------------------------------------------------------===//
 
 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
@@ -406,6 +415,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
 /// if either operand can be negative. Emulate it via spv.UMod.
+template <typename SignedAbsOp>
 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
                                     Value signOperand, OpBuilder &builder) {
   assert(lhs.getType() == rhs.getType());
@@ -414,8 +424,8 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
   Type type = lhs.getType();
 
   // Calculate the remainder with spv.UMod.
-  Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
-  Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
+  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
 
   // Fix the sign.
@@ -429,11 +439,26 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
 }
 
 LogicalResult
-RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
-                                        adaptor.getOperands()[1],
-                                        adaptor.getOperands()[0], rewriter);
+RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
+      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+      adaptor.getOperands()[0], rewriter);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOpOCLPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
+      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+      adaptor.getOperands()[0], rewriter);
   rewriter.replaceOp(op, result);
 
   return success();
@@ -762,7 +787,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
-    RemSIOpPattern,
+    RemSIOpGLSLPattern, RemSIOpOCLPattern,
     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
     XOrIOpLogicalPattern, XOrIOpBooleanPattern,

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 5ab13e780d0d4..291c7fdd77b1c 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -43,14 +43,8 @@ func @scalar_srem(%lhs: i32, %rhs: i32) {
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func @float32_unary_scalar(%arg0: f32) {
-  // CHECK: spv.GLSL.FAbs %{{.*}}: f32
-  %0 = math.abs %arg0 : f32
-  // CHECK: spv.GLSL.Ceil %{{.*}}: f32
-  %1 = math.ceil %arg0 : f32
   // CHECK: spv.FNegate %{{.*}}: f32
-  %5 = arith.negf %arg0 : f32
-  // CHECK: spv.GLSL.Floor %{{.*}}: f32
-  %10 = math.floor %arg0 : f32
+  %0 = arith.negf %arg0 : f32
   return
 }
 
@@ -842,3 +836,39 @@ func @sitofp(%arg0 : i64) -> f64 {
 }
 
 } // end module
+
+// -----
+
+// Check OpenCL lowering of arith.remsi
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int16, Kernel], []>, {}>
+} {
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+  // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : i32
+  // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : i32
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+  %0 = arith.remsi %lhs, %rhs: i32
+  return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+  // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : vector<3xi16>
+  // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : vector<3xi16>
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+  %0 = arith.remsi %arg0, %arg1: vector<3xi16>
+  return
+}
+
+} // end module

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index ad32a88a876ea..8cae1ca7d94ef 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
   %6 = math.tanh %arg0 : f32
   // CHECK: spv.GLSL.Sin %{{.*}}: f32
   %7 = math.sin %arg0 : f32
+  // CHECK: spv.GLSL.FAbs %{{.*}}: f32
+  %8 = math.abs %arg0 : f32
+  // CHECK: spv.GLSL.Ceil %{{.*}}: f32
+  %9 = math.ceil %arg0 : f32
+  // CHECK: spv.GLSL.Floor %{{.*}}: f32
+  %10 = math.floor %arg0 : f32
   return
 }
 

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 8a1a3acc5f0cd..5bfd4e477c21c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
   %6 = math.tanh %arg0 : f32
   // CHECK: spv.OCL.sin %{{.*}}: f32
   %7 = math.sin %arg0 : f32
+  // CHECK: spv.OCL.fabs %{{.*}}: f32
+  %8 = math.abs %arg0 : f32
+  // CHECK: spv.OCL.ceil %{{.*}}: f32
+  %9 = math.ceil %arg0 : f32
+  // CHECK: spv.OCL.floor %{{.*}}: f32
+  %10 = math.floor %arg0 : f32
   return
 }
 


        


More information about the Mlir-commits mailing list