[Mlir-commits] [mlir] 5c16eeb - [mlir][spirv] Define spv.IAddCarry

Jakub Kuderski llvmlistbot at llvm.org
Fri Aug 5 13:47:48 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-05T16:45:51-04:00
New Revision: 5c16eeb7ee13ab0b5eb52571998b9494475db301

URL: https://github.com/llvm/llvm-project/commit/5c16eeb7ee13ab0b5eb52571998b9494475db301
DIFF: https://github.com/llvm/llvm-project/commit/5c16eeb7ee13ab0b5eb52571998b9494475db301.diff

LOG: [mlir][spirv] Define spv.IAddCarry

Based on `spv.ISubBorrow` from D127909.
Also resolved some clang-tidy warnings.

Reviewed By: antiagainst, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D131281

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 30fd95f6a1fe5..7ce34dcd5dedb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -310,6 +310,55 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd",
 
 // -----
 
+def SPV_IAddCarryOp : SPV_BinaryOp<"IAddCarry",
+                                   SPV_AnyStruct, SPV_Integer,
+                                   [Commutative, NoSideEffect]> {
+  let summary = [{
+    Integer addition of Operand 1 and Operand 2, including the carry.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct.  The struct must have two
+    members, and the two members must be the same type.  The member type
+    must be a scalar or vector of integer type, whose Signedness operand is
+    0.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as unsigned integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits (full component width) of
+    the addition.
+
+    Member 1 of the result gets the high-order (carry) bit of the result of
+    the addition. That is, it gets the value 1 if the addition overflowed
+    the component width, and 0 otherwise.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IAddCarry %0, %1 : !spv.struct<(i32, i32)>
+    %5 = spv.IAddCarry %3, %4 : !spv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<SPV_Integer>:$operand1,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$operand2
+  );
+
+  let results = (outs
+    SPV_AnyStruct:$result
+  );
+
+  let hasVerifier = 1;
+}
+
+// -----
+
 def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul",
                                         SPV_Integer,
                                         [Commutative, UsableInSpecConstantOp]> {

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9199d0141e1a1..985498b42fce5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4094,6 +4094,7 @@ def SPV_OC_OpFMod                      : I32EnumAttrCase<"OpFMod", 141>;
 def SPV_OC_OpVectorTimesScalar         : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix         : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpIAddCarry                 : I32EnumAttrCase<"OpIAddCarry", 149>;
 def SPV_OC_OpISubBorrow                : I32EnumAttrCase<"OpISubBorrow", 150>;
 def SPV_OC_OpIsNan                     : I32EnumAttrCase<"OpIsNan", 156>;
 def SPV_OC_OpIsInf                     : I32EnumAttrCase<"OpIsInf", 157>;
@@ -4219,16 +4220,16 @@ def SPV_OpcodeAttr :
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
       SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
-      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpISubBorrow,
-      SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIAddCarry,
+      SPV_OC_OpISubBorrow, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
+      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index ac78e08d56097..dae92ddaf821d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpDefinition.h"
@@ -2840,6 +2839,55 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
   printGroupNonUniformArithmeticOp(*this, p);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.IAddCarryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::IAddCarryOp::verify() {
+  auto resultType = getType().cast<spirv::StructType>();
+  if (resultType.getNumElements() != 2)
+    return emitOpError("expected result struct type containing two members");
+
+  if (!llvm::is_splat(llvm::makeArrayRef(
+          {operand1().getType(), operand2().getType(),
+           resultType.getElementType(0), resultType.getElementType(1)})))
+    return emitOpError(
+        "expected all operand types and struct member types are the same");
+
+  return success();
+}
+
+ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operands) || parser.parseColon())
+    return failure();
+
+  Type resultType;
+  SMLoc loc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+
+  auto structType = resultType.dyn_cast<spirv::StructType>();
+  if (!structType || structType.getNumElements() != 2)
+    return parser.emitError(loc, "expected spv.struct type with two members");
+
+  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
+  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  printer.printOptionalAttrDict((*this)->getAttrs());
+  printer.printOperands((*this)->getOperands());
+  printer << " : " << getType();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.ISubBorrowOp
 //===----------------------------------------------------------------------===//
@@ -2849,12 +2897,9 @@ LogicalResult spirv::ISubBorrowOp::verify() {
   if (resultType.getNumElements() != 2)
     return emitOpError("expected result struct type containing two members");
 
-  SmallVector<Type, 4> types;
-  types.push_back(operand1().getType());
-  types.push_back(operand2().getType());
-  types.push_back(resultType.getElementType(0));
-  types.push_back(resultType.getElementType(1));
-  if (!llvm::is_splat(types))
+  if (!llvm::is_splat(llvm::makeArrayRef(
+          {operand1().getType(), operand2().getType(),
+           resultType.getElementType(0), resultType.getElementType(1)})))
     return emitOpError(
         "expected all operand types and struct member types are the same");
 
@@ -2862,9 +2907,9 @@ LogicalResult spirv::ISubBorrowOp::verify() {
 }
 
 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
-                                       OperationState &state) {
+                                       OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-  if (parser.parseOptionalAttrDict(state.attributes) ||
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseOperandList(operands) || parser.parseColon())
     return failure();
 
@@ -2878,10 +2923,10 @@ ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
     return parser.emitError(loc, "expected spv.struct type with two members");
 
   SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
-  if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
+  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
     return failure();
 
-  state.addTypes(resultType);
+  result.addTypes(resultType);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 22159b0dde15c..fb2622d95c30f 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -150,6 +150,58 @@ func.func @isub_scalar(%arg: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iadd_carry_scalar
+func.func @iadd_carry_scalar(%arg: i32) -> !spv.struct<(i32, i32)> {
+  // CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(i32, i32)>
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32)>
+  return %0 : !spv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @iadd_carry_vector
+func.func @iadd_carry_vector(%arg: vector<3xi32>) -> !spv.struct<(vector<3xi32>, vector<3xi32>)> {
+  // CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i32, i32)> {
+  // expected-error @+1 {{expected spv.struct type with two members}}
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32, i32)>
+  return %0 : !spv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32)> {
+  // expected-error @+1 {{expected result struct type containing two members}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32)>
+  return %0 : !spv.struct<(i32)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i64)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32, i64)>
+  return %0 : !spv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i64) -> !spv.struct<(i32, i32)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i64, i64) -> !spv.struct<(i32, i32)>
+  return %0 : !spv.struct<(i32, i32)>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.ISubBorrow
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list