[Mlir-commits] [mlir] 4309170 - [mlir] Add arith.addi_carry op

Jakub Kuderski llvmlistbot at llvm.org
Wed Aug 17 14:03:00 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-17T17:01:20-04:00
New Revision: 4309170c871d3e17760b36a54936442e97cf265f

URL: https://github.com/llvm/llvm-project/commit/4309170c871d3e17760b36a54936442e97cf265f
DIFF: https://github.com/llvm/llvm-project/commit/4309170c871d3e17760b36a54936442e97cf265f.diff

LOG: [mlir] Add arith.addi_carry op

The `arith.addi_carry` op implements integer addition that also reports unsigned overflow. The carry is returned via the second result, as `i1` (or as a vector/tensor of `i1` for vector/tensor operands).

Reviewed By: antiagainst, bondhugula

Differential Revision: https://reviews.llvm.org/D131893

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/Dialect/Arithmetic/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 9eb71cdde74c6..958e9da261157 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -202,6 +202,49 @@ def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOp
+//===----------------------------------------------------------------------===//
+
+def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
+    AllTypesMatch<["lhs", "rhs", "sum"]>]> {
+  let summary = "integer addition operation returning both the sum and carry";
+  let description = [{
+    The `addi_carry` operation takes two operands and returns two results: the
+    sum (same type as both operands), and the carry (boolean-like).
+
+    The sum wraps around, i.e., it is computed modulo 2^n where n is the
+    bitwidth of the (element) type. The carry is set iff the addition
+    overflows when both operands are interpreted as unsigned integers. For
+    vector and tensor operands, both results are computed element-wise.
+
+    Example:
+
+    ```mlir
+    // Scalar addition.
+    %sum, %carry = arith.addi_carry %a, %b : i64, i1
+
+    // Vector element-wise addition.
+    %vec:2 = arith.addi_carry %c, %d : vector<4xi32>, vector<4xi1>
+
+    // Tensor element-wise addition.
+    %ten:2 = arith.addi_carry %e, %f : tensor<4x?xi8>, tensor<4x?xi1>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
+  }];
+
+  let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 537dc6cbe96ed..8546a373be4cd 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <cassert>
 #include <utility>
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -15,9 +16,9 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallString.h"
 
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/SmallString.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -216,6 +217,81 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
+  if (auto vectorTy = getSum().getType().dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vectorTy.getShape()); // Only vectors unroll.
+  return llvm::None;
+}
+
+// Returns the 1-bit carry-out of the wrapping addition `sum = operand + x`:
+// the addition overflowed iff the (unsigned) `sum` is less than `operand`.
+static APInt calculateCarry(const APInt &sum, const APInt &operand) {
+  return APInt(/*numBits=*/1, /*val=*/sum.ult(operand) ? 1 : 0);
+}
+
+LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+  auto carryTy = getCarry().getType();
+  // addi_carry(x, 0) -> x, false
+  if (matchPattern(getRhs(), m_Zero())) {
+    Builder builder(getContext());
+    auto falseValue = builder.getZeroAttr(carryTy);
+
+    results.push_back(getLhs());
+    results.push_back(falseValue);
+    return success();
+  }
+
+  // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
+  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
+  // operands. If that succeeds, calculate the carry boolean based on the sum
+  // and the first (constant) operand, `lhs`. Note that we cannot simply call
+  // `constFoldBinaryOp` again to calculate the carry (bit) because the
+  // constructed attribute is of the same element type as both operands.
+  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
+          operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
+    Attribute carryAttr;
+    if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+      // Both arguments are scalars, calculate the scalar carry value.
+      auto sum = sumAttr.cast<IntegerAttr>();
+      carryAttr = IntegerAttr::get(
+          carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
+    } else if (operands[0].isa<SplatElementsAttr>() &&
+               sumAttr.isa<SplatElementsAttr>()) {
+      // Both `lhs` and the sum are splats: the carry is a splat as well.
+      APInt carry = calculateCarry(
+          sumAttr.cast<SplatElementsAttr>().getSplatValue<APInt>(),
+          operands[0].cast<SplatElementsAttr>().getSplatValue<APInt>());
+      carryAttr = SplatElementsAttr::get(carryTy, carry);
+    } else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
+      // Otherwise calculate element-wise carry values. This also covers a
+      // splat `lhs` with a non-splat `rhs` (which yields a non-splat sum).
+      auto sum = sumAttr.cast<ElementsAttr>();
+      const auto numElems = static_cast<size_t>(sum.getNumElements());
+      SmallVector<APInt> carryValues;
+      carryValues.reserve(numElems);
+
+      auto sumIt = sum.value_begin<APInt>();
+      auto lhsIt = lhs.value_begin<APInt>();
+      for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
+        carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
+      carryAttr = DenseElementsAttr::get(carryTy, carryValues);
+    } else {
+      return failure();
+    }
+
+    results.push_back(sumAttr);
+    results.push_back(carryAttr);
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 0a4d08ae071af..f99a0702f9905 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -544,6 +544,87 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// CHECK-LABEL: @addiCarryZeroRhs
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant false
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1  // x + 0 never carries.
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryZeroRhsSplat
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  %zero = arith.constant dense<0> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>  // Splat zero rhs: all-false carry.
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryZeroLhs
+//  CHECK-NEXT:   %[[false:.+]] = arith.constant false
+//  CHECK-NEXT:   return %arg0, %[[false]]
+func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
+  %zero = arith.constant 0 : i32
+  %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1  // Zero on the lhs; the op is Commutative.
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstants
+//  CHECK-DAG:    %[[false:.+]] = arith.constant false
+//  CHECK-DAG:    %[[c50:.+]] = arith.constant 50 : i32
+//  CHECK-NEXT:   return %[[c50]], %[[false]]
+func.func @addiCarryConstants() -> (i32, i1) {
+  %c13 = arith.constant 13 : i32
+  %c37 = arith.constant 37 : i32
+  %sum, %carry = arith.addi_carry %c13, %c37: i32, i1  // 13 + 37 = 50, no overflow.
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow1
+//  CHECK-DAG:    %[[true:.+]] = arith.constant true
+//  CHECK-DAG:    %[[c0:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[c0]], %[[true]]
+func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32  // 2^32 - 1, i.e. all ones.
+  %c1 = arith.constant 1 : i32
+  %sum, %carry = arith.addi_carry %max, %c1: i32, i1  // 0xFFFFFFFF + 1 wraps to 0 with a carry.
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow2
+//  CHECK-DAG:    %[[true:.+]] = arith.constant true
+//  CHECK-DAG:    %[[c_2:.+]] = arith.constant -2 : i32
+// CHECK-NEXT:    return %[[c_2]], %[[true]]
+func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
+  %max = arith.constant 4294967295 : i32
+  %sum, %carry = arith.addi_carry %max, %max: i32, i1  // 2 * 0xFFFFFFFF wraps to 0xFFFFFFFE (-2) with a carry.
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflowVector
+//  CHECK-DAG:    %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
+//  CHECK-DAG:    %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
+// CHECK-NEXT:    return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
+  %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>  // Lane 2: 3 + 0xFFFFFFFF wraps to 2 with a carry.
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryConstantsSplatVector
+//   CHECK-DAG:   %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
+//   CHECK-DAG:   %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
+//  CHECK-NEXT:   return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
+  %v1 = arith.constant dense<1> : vector<4xi32>
+  %v2 = arith.constant dense<2> : vector<4xi32>
+  %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>  // Splat operands fold to splat results.
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // CHECK-LABEL: @notCmpEQ
 //       CHECK:   %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 //       CHECK:   return %[[cres]]

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 19c427b5e744f..2ae8dd2123959 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -110,6 +110,38 @@ func.func @func_with_ops(f32) {
 
 // -----
 
+func.func @func_with_ops(%a: f32) {
+  // expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
+  %r:2 = arith.addi_carry %a, %a : f32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: i32) {
+  // expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}}
+  %r:2 = arith.addi_carry %a, %a : i32, i32
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
+  return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+  // expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
+  %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
+  return
+}
+
+// -----
+
 func.func @func_with_ops(i32) {
 ^bb0(%a : i32):
   %sf = arith.addf %a, %a : i32  // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}

diff  --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 61f9a2d8e5125..e9bb19838458f 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -25,6 +25,30 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_addi_carry
+func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1  // Both results named individually; %carry is unused.
+  return %sum : i64
+}
+
+// CHECK-LABEL: test_addi_carry_tensor
+func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>  // Carry has the operand shape with i1 elements.
+  return %sum : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_vector
+func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>  // Result pack: %0#0 is the sum, %0#1 the carry.
+  return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_scalable_vector
+func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>  // Scalable vectors are accepted too.
+  return %0#0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_subi
 func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.subi %arg0, %arg1 : i64


        


More information about the Mlir-commits mailing list