[Mlir-commits] [mlir] 8187b8c - [mlir] Rename arith.addi_carry to arith.addui_carry

Jakub Kuderski llvmlistbot at llvm.org
Wed Aug 24 07:43:47 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-24T10:41:06-04:00
New Revision: 8187b8c42b07354995f89baeb5583ffb9f83f82e

URL: https://github.com/llvm/llvm-project/commit/8187b8c42b07354995f89baeb5583ffb9f83f82e
DIFF: https://github.com/llvm/llvm-project/commit/8187b8c42b07354995f89baeb5583ffb9f83f82e.diff

LOG: [mlir] Rename arith.addi_carry to arith.addui_carry

The intention is to have this op lowered to
`llvm.intr.uadd.with.overflow` or `spv.IAddCarry`. LLVM has a second
intrinsic for signed add-with-overflow, `llvm.intr.sadd.with.overflow`,
with different semantics. Therefore we should have two ops in `arith`,
and be explicit about signed/unsigned semantics.

Rename `arith.addi_carry` to `arith.addui_carry` before we introduce a
signed version of this op: `arith.addsi_carry`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132491

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/Dialect/Arithmetic/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 958e9da261157..f6e8a9af02b44 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -203,24 +203,26 @@ def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
 }
 
 
-def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
+def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Commutative,
     AllTypesMatch<["lhs", "rhs", "sum"]>]> {
-  let summary = "integer addition operation returning both the sum and carry";
+  let summary = "unsigned integer addition operation returning sum and carry";
   let description = [{
-    The `addi_carry` operation takes two operands and returns two results: the
-    sum (same type as both operands), and the carry (boolean-like).
+    The `addui_carry` operation takes two operands and returns two results: the
+    sum (same type as both operands), and the carry (boolean-like). The carry
+    value `1` indicates unsigned addition overflow, while `0` indicates no
+    overflow.
 
     Example:
 
     ```mlir
     // Scalar addition.
-    %sum, %carry = arith.addi_carry %b, %c : i64, i1
+    %sum, %carry = arith.addui_carry %b, %c : i64, i1
 
     // Vector element-wise addition.
-    %b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>
+    %b:2 = arith.addui_carry %g, %h : vector<4xi32>, vector<4xi1>
 
     // Tensor element-wise addition.
-    %c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
+    %c:2 = arith.addui_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
     ```
   }];
 

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 56a241cd45122..9b7621f2b6693 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -195,12 +195,13 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.addi_carry to spv.IAddCarry.
-class AddICarryOpPattern final : public OpConversionPattern<arith::AddICarryOp> {
+/// Converts arith.addui_carry to spv.IAddCarry.
+class AddICarryOpPattern final
+    : public OpConversionPattern<arith::AddUICarryOp> {
 public:
-  using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern;
+  using OpConversionPattern<arith::AddUICarryOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+  matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -850,7 +851,7 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+AddICarryOpPattern::matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
   Type dstElemTy = adaptor.getLhs().getType();
   auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy});
@@ -866,8 +867,7 @@ AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
 
   // Convert the carry value to boolean.
   Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
-  Value carryResult =
-      rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
 
   rewriter.replaceOp(op, {sumResult, carryResult});
   return success();

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 11e64ffd8f720..e57422b9d341e 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -218,10 +218,10 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 //===----------------------------------------------------------------------===//
-// AddICarryOp
+// AddUICarryOp
 //===----------------------------------------------------------------------===//
 
-Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
+Optional<SmallVector<int64_t, 4>> arith::AddUICarryOp::getShapeForUnroll() {
   if (auto vt = getType(0).dyn_cast<VectorType>())
     return llvm::to_vector<4>(vt.getShape());
   return None;
@@ -233,10 +233,11 @@ static APInt calculateCarry(const APInt &sum, const APInt &operand) {
   return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
 }
 
-LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
-                                       SmallVectorImpl<OpFoldResult> &results) {
+LogicalResult
+arith::AddUICarryOp::fold(ArrayRef<Attribute> operands,
+                          SmallVectorImpl<OpFoldResult> &results) {
   auto carryTy = getCarry().getType();
-  // addi_carry(x, 0) -> x, false
+  // addui_carry(x, 0) -> x, false
   if (matchPattern(getRhs(), m_Zero())) {
     auto carryZero = APInt::getZero(1);
     Builder builder(getContext());
@@ -247,7 +248,7 @@ LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
     return success();
   }
 
-  // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
+  // addui_carry(constant_a, constant_b) -> constant_sum, constant_carry
   // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
   // operands. If that succeeds, calculate the carry boolean based on the sum
   // and the first (constant) operand, `lhs`. Note that we cannot simply call

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index ca48648b7c1dd..4ed91ca31abb6 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -73,29 +73,29 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) {
 }
 
 // Check integer add-with-carry conversions.
-// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-LABEL: @int32_scalar_addui_carry
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
-func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+func.func @int32_scalar_addui_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
   // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
   // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
   // CHECK-DAG:  %[[C0:.+]]  = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
   // CHECK-DAG:  %[[ONE:.+]] = spv.Constant 1 : i32
   // CHECK-NEXT: %[[C1:.+]]  = spv.IEqual %[[C0]], %[[ONE]] : i32
   // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
-  %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+  %sum, %carry = arith.addui_carry %lhs, %rhs: i32, i1
   return %sum, %carry : i32, i1
 }
 
-// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-LABEL: @int32_vector_addui_carry
 // CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
-func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+func.func @int32_vector_addui_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
   // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
   // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
   // CHECK-DAG:  %[[C0:.+]]  = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
   // CHECK-DAG:  %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
   // CHECK-NEXT: %[[C1:.+]]  = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
   // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
-  %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addui_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
   return %sum, %carry : vector<4xi32>, vector<4xi1>
 }
 

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index a1ab1be5f63d9..7076ab81e6233 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -549,7 +549,7 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
 //  CHECK-NEXT:   return %arg0, %[[false]]
 func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
   %zero = arith.constant 0 : i32
-  %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
+  %sum, %carry = arith.addui_carry %arg0, %zero: i32, i1
   return %sum, %carry : i32, i1
 }
 
@@ -558,7 +558,7 @@ func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
 //  CHECK-NEXT:   return %arg0, %[[false]]
 func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
   %zero = arith.constant dense<0> : vector<4xi32>
-  %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addui_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
   return %sum, %carry : vector<4xi32>, vector<4xi1>
 }
 
@@ -567,7 +567,7 @@ func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector
 //  CHECK-NEXT:   return %arg0, %[[false]]
 func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
   %zero = arith.constant 0 : i32
-  %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
+  %sum, %carry = arith.addui_carry %zero, %arg0: i32, i1
   return %sum, %carry : i32, i1
 }
 
@@ -578,7 +578,7 @@ func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
 func.func @addiCarryConstants() -> (i32, i1) {
   %c13 = arith.constant 13 : i32
   %c37 = arith.constant 37 : i32
-  %sum, %carry = arith.addi_carry %c13, %c37: i32, i1
+  %sum, %carry = arith.addui_carry %c13, %c37: i32, i1
   return %sum, %carry : i32, i1
 }
 
@@ -589,7 +589,7 @@ func.func @addiCarryConstants() -> (i32, i1) {
 func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
   %max = arith.constant 4294967295 : i32
   %c1 = arith.constant 1 : i32
-  %sum, %carry = arith.addi_carry %max, %c1: i32, i1
+  %sum, %carry = arith.addui_carry %max, %c1: i32, i1
   return %sum, %carry : i32, i1
 }
 
@@ -599,7 +599,7 @@ func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
 // CHECK-NEXT:    return %[[c_2]], %[[true]]
 func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
   %max = arith.constant 4294967295 : i32
-  %sum, %carry = arith.addi_carry %max, %max: i32, i1
+  %sum, %carry = arith.addui_carry %max, %max: i32, i1
   return %sum, %carry : i32, i1
 }
 
@@ -610,7 +610,7 @@ func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
 func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
   %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
   %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
-  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
   return %sum, %carry : vector<4xi32>, vector<4xi1>
 }
 
@@ -621,7 +621,7 @@ func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
 func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
   %v1 = arith.constant dense<1> : vector<4xi32>
   %v2 = arith.constant dense<2> : vector<4xi32>
-  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
   return %sum, %carry : vector<4xi32>, vector<4xi1>
 }
 

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 2ae8dd2123959..93307562c2de3 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -111,32 +111,32 @@ func.func @func_with_ops(f32) {
 // -----
 
 func.func @func_with_ops(%a: f32) {
-  // expected-error at +1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
-  %r:2 = arith.addi_carry %a, %a : f32, i32
+  // expected-error at +1 {{'arith.addui_carry' op operand #0 must be signless-integer-like}}
+  %r:2 = arith.addui_carry %a, %a : f32, i32
   return
 }
 
 // -----
 
 func.func @func_with_ops(%a: i32) {
-  // expected-error at +1 {{'arith.addi_carry' op result #1 must be bool-like}}
-  %r:2 = arith.addi_carry %a, %a : i32, i32
+  // expected-error at +1 {{'arith.addui_carry' op result #1 must be bool-like}}
+  %r:2 = arith.addui_carry %a, %a : i32, i32
   return
 }
 
 // -----
 
 func.func @func_with_ops(%a: vector<8xi32>) {
-  // expected-error at +1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
-  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
+  // expected-error at +1 {{'arith.addui_carry' op if an operand is non-scalar, then all results must be non-scalar}}
+  %r:2 = arith.addui_carry %a, %a : vector<8xi32>, i1
   return
 }
 
 // -----
 
 func.func @func_with_ops(%a: vector<8xi32>) {
-  // expected-error at +1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
-  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
+  // expected-error at +1 {{'arith.addui_carry' op all non-scalar operands/results must have the same shape and base type}}
+  %r:2 = arith.addui_carry %a, %a : vector<8xi32>, tensor<8xi1>
   return
 }
 

diff  --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index e9bb19838458f..56e17c798f831 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -25,27 +25,27 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
   return %0 : vector<[8]xi64>
 }
 
-// CHECK-LABEL: test_addi_carry
-func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
-  %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
+// CHECK-LABEL: test_addui_carry
+func.func @test_addui_carry(%arg0 : i64, %arg1 : i64) -> i64 {
+  %sum, %carry = arith.addui_carry %arg0, %arg1 : i64, i1
   return %sum : i64
 }
 
-// CHECK-LABEL: test_addi_carry_tensor
-func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
-  %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
+// CHECK-LABEL: test_addui_carry_tensor
+func.func @test_addui_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %sum, %carry = arith.addui_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
   return %sum : tensor<8x8xi64>
 }
 
-// CHECK-LABEL: test_addi_carry_vector
-func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
-  %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
+// CHECK-LABEL: test_addui_carry_vector
+func.func @test_addui_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.addui_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
   return %0#0 : vector<8xi64>
 }
 
-// CHECK-LABEL: test_addi_carry_scalable_vector
-func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
-  %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
+// CHECK-LABEL: test_addui_carry_scalable_vector
+func.func @test_addui_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.addui_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
   return %0#0 : vector<[8]xi64>
 }
 


        


More information about the Mlir-commits mailing list