[Mlir-commits] [mlir] 4ba45a7 - [mlir][StandardToSPIRV] Fix conversion for signed remainder

Lei Zhang llvmlistbot at llvm.org
Mon Jul 13 13:18:53 PDT 2020


Author: Lei Zhang
Date: 2020-07-13T16:15:31-04:00
New Revision: 4ba45a778a13eab1495a75a14682f874016f3d21

URL: https://github.com/llvm/llvm-project/commit/4ba45a778a13eab1495a75a14682f874016f3d21
DIFF: https://github.com/llvm/llvm-project/commit/4ba45a778a13eab1495a75a14682f874016f3d21.diff

LOG: [mlir][StandardToSPIRV] Fix conversion for signed remainder

Per the Vulkan SPIR-V environment spec, "for the OpSRem and OpSMod
instructions, if either operand is negative the result is undefined."
So we cannot directly use spv.SRem/spv.SMod if either operand can be
negative. Emulate it via spv.UMod.

Because the emulation uses spv.SNegate, this commit also defines
spv.SNegate.

Differential Revision: https://reviews.llvm.org/D83679

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
    mlir/test/Dialect/SPIRV/arithmetic-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index 350e3659a28d..5a12e6f36ec4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -452,6 +452,31 @@ def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod", SPV_Integer, []> {
 
 // -----
 
+def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate", SPV_Integer, []> {
+  let summary = "Signed-integer subtract of Operand from zero.";
+  // Corresponds to SPV_OC_OpSNegate (opcode 126) in SPIRVBase.td.
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+    Operand’s type  must be a scalar or vector of integer type.  It must
+    have the same number of components as Result Type.  The component width
+    must equal the component width in Result Type.
+
+     Results are computed per component.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %1 = spv.SNegate %0 : i32
+    %3 = spv.SNegate %2 : vector<4xi32>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
   let summary = [{
     Signed remainder operation for the remainder whose sign matches the sign

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index f114c878569d..cbff82efdfd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3150,6 +3150,7 @@ def SPV_OC_OpUConvert                  : I32EnumAttrCase<"OpUConvert", 113>;
 def SPV_OC_OpSConvert                  : I32EnumAttrCase<"OpSConvert", 114>;
 def SPV_OC_OpFConvert                  : I32EnumAttrCase<"OpFConvert", 115>;
 def SPV_OC_OpBitcast                   : I32EnumAttrCase<"OpBitcast", 124>;
+def SPV_OC_OpSNegate                   : I32EnumAttrCase<"OpSNegate", 126>;
 def SPV_OC_OpFNegate                   : I32EnumAttrCase<"OpFNegate", 127>;
 def SPV_OC_OpIAdd                      : I32EnumAttrCase<"OpIAdd", 128>;
 def SPV_OC_OpFAdd                      : I32EnumAttrCase<"OpFAdd", 129>;
@@ -3271,41 +3272,42 @@ def SPV_OpcodeAttr :
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
-      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
-      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
-      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
-      SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
-      SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
-      SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
-      SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
-      SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
-      SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
-      SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
-      SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
-      SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
-      SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
-      SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
-      SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
-      SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
-      SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
-      SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
-      SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
-      SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
-      SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
-      SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
-      SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
-      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
-      SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
-      SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
-      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
-      SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
-      SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
-      SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
-      SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
-      SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
-      SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
-      SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
+      SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
+      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
+      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
+      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
+      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
+      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
+      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
+      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
+      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
+      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+      SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
+      SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
+      SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
+      SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
+      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
+      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
+      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
+      SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
+      SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
+      SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
+      SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
+      SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
+      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
+      SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
+      SPV_OC_OpCooperativeMatrixLengthNV
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index e1b477126a02..9789122809ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -26,6 +26,12 @@ class SPV_LogicalBinaryOp<string mnemonic, Type operandsType,
                                 SameOperandsAndResultShape])> {
   let parser = [{ return ::parseLogicalBinaryOp(parser, result); }];
   let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+
+  let builders = [
+    OpBuilder<
+      "OpBuilder &builder, OperationState &state, Value lhs, Value rhs",
+      "::buildLogicalBinaryOp(builder, state, lhs, rhs);">
+  ];
 }
 
 class SPV_LogicalUnaryOp<string mnemonic, Type operandType,

diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 6bb7a17ae46f..dad8bfc0173f 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,35 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Returns the remainder of `lhs` divided by `rhs`, whose sign follows
+/// `signOperand` (which must be `lhs` or `rhs`). Vulkan's SPIR-V environment
+/// spec says "for the OpSRem and OpSMod instructions, if either operand is
+/// negative the result is undefined", so spv.SRem/spv.SMod cannot be used
+/// when an operand may be negative; emulate via spv.UMod on absolute values.
+/// NOTE(review): if SAbs(INT_MIN) wraps to INT_MIN, an INT_MIN sign operand
+/// passes the `x == SAbs(x)` test below, yielding a non-negative result.
+static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
+                                    Value signOperand, OpBuilder &builder) {
+  assert(lhs.getType() == rhs.getType());
+  assert(lhs == signOperand || rhs == signOperand);
+
+  Type type = lhs.getType();
+
+  // Magnitude of the remainder: SAbs(lhs) umod SAbs(rhs).
+  Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
+  Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+
+  // Fix the sign: `x == SAbs(x)` is taken to mean "x is non-negative".
+  Value isPositive;
+  if (lhs == signOperand)
+    isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+  else
+    isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
+  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
+  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+}
+
 /// Returns the offset of the value in `targetBits` representation. `srcIdx` is
 /// an index into a 1-D array with each element having `sourceBits`. When
 /// accessing an element in the array treating as having elements of
@@ -308,6 +337,19 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
   }
 };
 
+/// Converts std.remi_signed to SPIR-V ops.
+///
+/// This cannot be merged into the template unary/binary pattern because Vulkan
+/// leaves spv.SRem and spv.SMod undefined if either operand is negative.
+class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
+public:
+  using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;
+  /// Replaces `remOp` with an spv.UMod-based emulation of signed remainder.
+  LogicalResult
+  matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts bitwise standard operations to SPIR-V operations. This is a special
 /// pattern other than the BinaryOpPatternPattern because if the operands are
 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
@@ -506,6 +548,20 @@ class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// SignedRemIOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult SignedRemIOpPattern::matchAndRewrite(
+    SignedRemIOp remOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // The dividend (operands[0]) is the sign operand: the result follows lhs.
+  Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
+                                        operands[1], operands[0], rewriter);
+  rewriter.replaceOp(remOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type.
 //===----------------------------------------------------------------------===//
@@ -1005,6 +1061,9 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
   patterns.insert<
+      // Unary and binary patterns
+      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
+      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
       UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
       UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
@@ -1020,7 +1079,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
       UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
       UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
-      UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
       UnaryAndBinaryOpPattern<SignedShiftRightOp,
                               spirv::ShiftRightArithmeticOp>,
       UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
@@ -1031,19 +1089,28 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
-      AllocOpPattern, DeallocOpPattern,
-      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
-      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
-      BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
-      CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
-      ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
+      SignedRemIOpPattern, XOrOpPattern,
+
+      // Comparison patterns
+      BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
+
+      // Constant patterns
+      ConstantCompositeOpPattern, ConstantScalarOpPattern,
+
+      // Memory patterns
+      AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+      LoadOpPattern, StoreOpPattern,
+
+      ReturnOpPattern, SelectOpPattern,
+
+      // Type cast patterns
       ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
       TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
-      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
-      context, typeConverter);
+      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
+                                                          typeConverter);
 }
 } // namespace mlir

diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6c8319224974..9d0570257d42 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -844,6 +844,18 @@ static LogicalResult verifyShiftOp(Operation *op) {
   return success();
 }
 
+static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
+                                 Value lhs, Value rhs) {
+  assert(lhs.getType() == rhs.getType());
+  // Result type: i1 for scalar operands, or an i1 vector of the same shape.
+  Type boolType = builder.getI1Type();
+  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+    boolType = VectorType::get(vecType.getShape(), boolType);
+  state.addTypes(boolType);
+
+  state.addOperands({lhs, rhs});
+}
+
 //===----------------------------------------------------------------------===//
 // spv.AccessChainOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index c232395d80db..a93bf792b34f 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -22,12 +22,23 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
   %2 = muli %lhs, %rhs: i32
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32
   %3 = divi_signed %lhs, %rhs: i32
-  // CHECK: spv.SRem %{{.*}}, %{{.*}}: i32
-  %4 = remi_signed %lhs, %rhs: i32
   // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32
-  %5 = divi_unsigned %lhs, %rhs: i32
+  %4 = divi_unsigned %lhs, %rhs: i32
   // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
-  %6 = remi_unsigned %lhs, %rhs: i32
+  %5 = remi_unsigned %lhs, %rhs: i32
+  return
+}
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+  // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32
+  // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+  %0 = remi_signed %lhs, %rhs: i32
   return
 }
 
@@ -75,13 +86,24 @@ func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
 
 // Check int vector types.
 // CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
-  // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi16>
-  %1 = remi_signed %arg1, %arg1: vector<3xi16>
   // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64>
-  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
+  %1 = divi_unsigned %arg1, %arg1: vector<4xi64>
+  return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+  // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : vector<3xi16>
+  // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : vector<3xi16>
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+  %0 = remi_signed %arg0, %arg1: vector<3xi16>
   return
 }
 
@@ -132,8 +154,8 @@ module attributes {
 func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
-  // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
-  %1 = remi_signed %arg1, %arg1: vector<3xi16>
+  // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<3xi32>
+  %1 = divi_signed %arg1, %arg1: vector<3xi16>
   return
 }
 

diff --git a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
index 55c67dafe6bb..9752c0d0e579 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
@@ -71,6 +71,11 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.SMod %arg0, %arg1 : vector<4xi32>
     spv.Return
   }
+  spv.func @snegate(%arg0 : vector<4xi32>) "None" {
+    // CHECK: {{%.*}} = spv.SNegate {{%.*}} : vector<4xi32>
+    %0 = spv.SNegate %arg0 : vector<4xi32>
+    spv.Return
+  }
   spv.func @srem(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" {
     // CHECK: {{%.*}} = spv.SRem {{%.*}}, {{%.*}} : vector<4xi32>
     %0 = spv.SRem %arg0, %arg1 : vector<4xi32>

diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
index 85998fb03efd..de574b1510c9 100644
--- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
@@ -174,6 +174,17 @@ func @smod_scalar(%arg: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.SNegate
+//===----------------------------------------------------------------------===//
+
+func @snegate_scalar(%arg: i32) -> i32 {
+  // CHECK: spv.SNegate
+  %0 = spv.SNegate %arg : i32
+  return %0 : i32
+}
+
+// -----
 //===----------------------------------------------------------------------===//
 // spv.SRem
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list