[Mlir-commits] [mlir] 4309170 - [mlir] Add arith.addi_carry op
Jakub Kuderski
llvmlistbot at llvm.org
Wed Aug 17 14:03:00 PDT 2022
Author: Jakub Kuderski
Date: 2022-08-17T17:01:20-04:00
New Revision: 4309170c871d3e17760b36a54936442e97cf265f
URL: https://github.com/llvm/llvm-project/commit/4309170c871d3e17760b36a54936442e97cf265f
DIFF: https://github.com/llvm/llvm-project/commit/4309170c871d3e17760b36a54936442e97cf265f.diff
LOG: [mlir] Add arith.addi_carry op
The `arith.addi_carry` op implements integer addition with overflows. The carry is returned via the second result, as `i1`.
Reviewed By: antiagainst, bondhugula
Differential Revision: https://reviews.llvm.org/D131893
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
mlir/test/Dialect/Arithmetic/invalid.mlir
mlir/test/Dialect/Arithmetic/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 9eb71cdde74c6..958e9da261157 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -202,6 +202,41 @@ def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
let hasCanonicalizer = 1;
}
+
+def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
+ AllTypesMatch<["lhs", "rhs", "sum"]>]> {
+ let summary = "integer addition operation returning both the sum and carry";
+ let description = [{
+ The `addi_carry` operation takes two operands and returns two results: the
+ sum (same type as both operands), and the carry (boolean-like).
+
+ Example:
+
+ ```mlir
+ // Scalar addition.
+ %sum, %carry = arith.addi_carry %b, %c : i64, i1
+
+ // Vector element-wise addition.
+ %b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>
+
+ // Tensor element-wise addition.
+ %c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
+ ```
+ }];
+
+ let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+ let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
+ }];
+
+ let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 537dc6cbe96ed..8546a373be4cd 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include <cassert>
#include <utility>
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -15,9 +16,9 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/SmallString.h"
using namespace mlir;
using namespace mlir::arith;
@@ -216,6 +217,81 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
context);
}
+//===----------------------------------------------------------------------===//
+// AddICarryOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
+ if (auto vt = getType(0).dyn_cast<VectorType>())
+ return llvm::to_vector<4>(vt.getShape());
+ return None;
+}
+
+// Returns the carry bit, assuming that `sum` is the result of addition of
+// `operand` and another number.
+static APInt calculateCarry(const APInt &sum, const APInt &operand) {
+ return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
+}
+
+LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto carryTy = getCarry().getType();
+ // addi_carry(x, 0) -> x, false
+ if (matchPattern(getRhs(), m_Zero())) {
+ auto carryZero = APInt::getZero(1);
+ Builder builder(getContext());
+ auto falseValue = builder.getZeroAttr(carryTy);
+
+ results.push_back(getLhs());
+ results.push_back(falseValue);
+ return success();
+ }
+
+ // addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
+ // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
+ // operands. If that succeeds, calculate the carry boolean based on the sum
+ // and the first (constant) operand, `lhs`. Note that we cannot simply call
+ // `constFoldBinaryOp` again to calculate the carry (bit) because the
+ // constructed attribute is of the same element type as both operands.
+ if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
+ operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
+ Attribute carryAttr;
+ if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+ // Both arguments are scalars, calculate the scalar carry value.
+ auto sum = sumAttr.cast<IntegerAttr>();
+ carryAttr = IntegerAttr::get(
+ carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
+ } else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
+ // Both arguments are splats, calculate the splat carry value.
+ auto sum = sumAttr.cast<SplatElementsAttr>();
+ APInt carry = calculateCarry(sum.getSplatValue<APInt>(),
+ lhs.getSplatValue<APInt>());
+ carryAttr = SplatElementsAttr::get(carryTy, carry);
+ } else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
+ // Otherwise, calculate element-wise carry values.
+ auto sum = sumAttr.cast<ElementsAttr>();
+ const auto numElems = static_cast<size_t>(sum.getNumElements());
+ SmallVector<APInt> carryValues;
+ carryValues.reserve(numElems);
+
+ auto sumIt = sum.value_begin<APInt>();
+ auto lhsIt = lhs.value_begin<APInt>();
+ for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
+ carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
+
+ carryAttr = DenseElementsAttr::get(carryTy, carryValues);
+ } else {
+ return failure();
+ }
+
+ results.push_back(sumAttr);
+ results.push_back(carryAttr);
+ return success();
+ }
+
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 0a4d08ae071af..f99a0702f9905 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -544,6 +544,87 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
return %add : index
}
+// CHECK-LABEL: @addiCarryZeroRhs
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
+ %zero = arith.constant 0 : i32
+ %sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryZeroRhsSplat
+// CHECK-NEXT: %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+ %zero = arith.constant dense<0> : vector<4xi32>
+ %sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
+ return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryZeroLhs
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: return %arg0, %[[false]]
+func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
+ %zero = arith.constant 0 : i32
+ %sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstants
+// CHECK-DAG: %[[false:.+]] = arith.constant false
+// CHECK-DAG: %[[c50:.+]] = arith.constant 50 : i32
+// CHECK-NEXT: return %[[c50]], %[[false]]
+func.func @addiCarryConstants() -> (i32, i1) {
+ %c13 = arith.constant 13 : i32
+ %c37 = arith.constant 37 : i32
+ %sum, %carry = arith.addi_carry %c13, %c37: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow1
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[c0]], %[[true]]
+func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
+ %max = arith.constant 4294967295 : i32
+ %c1 = arith.constant 1 : i32
+ %sum, %carry = arith.addi_carry %max, %c1: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflow2
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[c_2:.+]] = arith.constant -2 : i32
+// CHECK-NEXT: return %[[c_2]], %[[true]]
+func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
+ %max = arith.constant 4294967295 : i32
+ %sum, %carry = arith.addi_carry %max, %max: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addiCarryConstantsOverflowVector
+// CHECK-DAG: %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
+// CHECK-DAG: %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
+// CHECK-NEXT: return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
+ %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
+ %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
+ %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+ return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
+// CHECK-LABEL: @addiCarryConstantsSplatVector
+// CHECK-DAG: %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
+// CHECK-DAG: %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
+// CHECK-NEXT: return %[[sum]], %[[carry]]
+func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
+ %v1 = arith.constant dense<1> : vector<4xi32>
+ %v2 = arith.constant dense<2> : vector<4xi32>
+ %sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+ return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
// CHECK-LABEL: @notCmpEQ
// CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 19c427b5e744f..2ae8dd2123959 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -110,6 +110,38 @@ func.func @func_with_ops(f32) {
// -----
+func.func @func_with_ops(%a: f32) {
+ // expected-error @+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
+ %r:2 = arith.addi_carry %a, %a : f32, i32
+ return
+}
+
+// -----
+
+func.func @func_with_ops(%a: i32) {
+ // expected-error @+1 {{'arith.addi_carry' op result #1 must be bool-like}}
+ %r:2 = arith.addi_carry %a, %a : i32, i32
+ return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+ // expected-error @+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
+ %r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
+ return
+}
+
+// -----
+
+func.func @func_with_ops(%a: vector<8xi32>) {
+ // expected-error @+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
+ %r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
+ return
+}
+
+// -----
+
func.func @func_with_ops(i32) {
^bb0(%a : i32):
%sf = arith.addf %a, %a : i32 // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 61f9a2d8e5125..e9bb19838458f 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -25,6 +25,30 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_addi_carry
+func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
+ %sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
+ return %sum : i64
+}
+
+// CHECK-LABEL: test_addi_carry_tensor
+func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+ %sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
+ return %sum : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_vector
+func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+ %0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
+ return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_addi_carry_scalable_vector
+func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+ %0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
+ return %0#0 : vector<[8]xi64>
+}
+
// CHECK-LABEL: test_subi
func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.subi %arg0, %arg1 : i64
More information about the Mlir-commits
mailing list