[Mlir-commits] [mlir] 28246b7 - [mlir][arith] Rename addui_carry to addui_extended
Jakub Kuderski
llvmlistbot at llvm.org
Wed Dec 7 14:18:56 PST 2022
Author: Jakub Kuderski
Date: 2022-12-07T17:15:56-05:00
New Revision: 28246b7e759708e8e667cadef11b6a516c258dc6
URL: https://github.com/llvm/llvm-project/commit/28246b7e759708e8e667cadef11b6a516c258dc6
DIFF: https://github.com/llvm/llvm-project/commit/28246b7e759708e8e667cadef11b6a516c258dc6.diff
LOG: [mlir][arith] Rename addui_carry to addui_extended
The goal is to make the naming of the future `_extended` ops more
consistent. With unsigned addition, the carry value/flag and overflow
bit are the same, but this is not true when it comes to signed addition.
Also rename the second result from `carry` to `overflow`.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139569
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/emulate-wide-int.mlir
mlir/test/Dialect/Arith/invalid.mlir
mlir/test/Dialect/Arith/ops.mlir
mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index cc1801b8f7361..6c7244bc7845c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -222,33 +222,36 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
}
-def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Pure, Commutative,
+def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
AllTypesMatch<["lhs", "rhs", "sum"]>]> {
- let summary = "unsigned integer addition operation returning sum and carry";
+ let summary = [{
+ extended unsigned integer addition operation returning sum and overflow bit
+ }];
+
let description = [{
- The `addui_carry` operation takes two operands and returns two results: the
- sum (same type as both operands), and the carry (boolean-like). The carry
- value `1` indicates unsigned addition overflow, while indicates `0` no
- overflow.
+ Performs (N+1)-bit addition on zero-extended operands. Returns two results:
+ the N-bit sum (same type as both operands), and the overflow bit
+ (boolean-like), where `1` indicates unsigned addition overflow, while `0`
+ indicates no overflow.
Example:
```mlir
// Scalar addition.
- %sum, %carry = arith.addui_carry %b, %c : i64, i1
+ %sum, %overflow = arith.addui_extended %b, %c : i64, i1
// Vector element-wise addition.
- %b:2 = arith.addui_carry %g, %h : vector<4xi32>, vector<4xi1>
+ %d:2 = arith.addui_extended %e, %f : vector<4xi32>, vector<4xi1>
// Tensor element-wise addition.
- %c:2 = arith.addui_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
+ %x:2 = arith.addui_extended %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
```
}];
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
- let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
+ let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
- $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
+ $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
let builders = [
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 3ad01556b2f69..0289bea88b504 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -133,12 +133,12 @@ using IndexCastOpSILowering =
using IndexCastOpUILowering =
IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
-struct AddUICarryOpLowering
- : public ConvertOpToLLVMPattern<arith::AddUICarryOp> {
+struct AddUIExtendedOpLowering
+ : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
+ matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -223,15 +223,15 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// AddUICarryOpLowering
+// AddUIExtendedOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult AddUICarryOpLowering::matchAndRewrite(
- arith::AddUICarryOp op, OpAdaptor adaptor,
+LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
+ arith::AddUIExtendedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Type operandType = adaptor.getLhs().getType();
Type sumResultType = op.getSum().getType();
- Type carryResultType = op.getCarry().getType();
+ Type overflowResultType = op.getOverflow().getType();
if (!LLVM::isCompatibleType(operandType))
return failure();
@@ -241,16 +241,16 @@ LogicalResult AddUICarryOpLowering::matchAndRewrite(
// Handle the scalar and 1D vector cases.
if (!operandType.isa<LLVM::LLVMArrayType>()) {
- Type newCarryType = typeConverter->convertType(carryResultType);
+ Type newOverflowType = typeConverter->convertType(overflowResultType);
Type structType =
- LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newCarryType});
+ LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
loc, structType, adaptor.getLhs(), adaptor.getRhs());
Value sumExtracted =
rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
- Value carryExtracted =
+ Value overflowExtracted =
rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
- rewriter.replaceOp(op, {sumExtracted, carryExtracted});
+ rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
return success();
}
@@ -374,7 +374,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
AddFOpLowering,
AddIOpLowering,
AndIOpLowering,
- AddUICarryOpLowering,
+ AddUIExtendedOpLowering,
BitcastOpLowering,
ConstantOpLowering,
CmpFOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d550e0e33f3e5..a127dd8f4e8a6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -213,13 +213,13 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts arith.addui_carry to spirv.IAddCarry.
-class AddICarryOpPattern final
- : public OpConversionPattern<arith::AddUICarryOp> {
+/// Converts arith.addui_extended to spirv.IAddCarry.
+class AddUIExtendedOpPattern final
+ : public OpConversionPattern<arith::AddUIExtendedOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
+ matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -920,12 +920,12 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// AddICarryOpPattern
+// AddUIExtendedOpPattern
//===----------------------------------------------------------------------===//
-LogicalResult
-AddICarryOpPattern::matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
+LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
+ arith::AddUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Type dstElemTy = adaptor.getLhs().getType();
Location loc = op->getLoc();
Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
@@ -1040,7 +1040,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
- AddICarryOpPattern, SelectOpPattern,
+ AddUIExtendedOpPattern, SelectOpPattern,
MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 00e23961208d4..0a2a8a9f550be 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -219,75 +219,76 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
//===----------------------------------------------------------------------===//
-// AddUICarryOp
+// AddUIExtendedOp
//===----------------------------------------------------------------------===//
-Optional<SmallVector<int64_t, 4>> arith::AddUICarryOp::getShapeForUnroll() {
+Optional<SmallVector<int64_t, 4>> arith::AddUIExtendedOp::getShapeForUnroll() {
if (auto vt = getType(0).dyn_cast<VectorType>())
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
-// Returns the carry bit, assuming that `sum` is the result of addition of
-// `operand` and another number.
-static APInt calculateCarry(const APInt &sum, const APInt &operand) {
+// Returns the overflow bit, assuming that `sum` is the result of unsigned
+// addition of `operand` and another number.
+static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}
LogicalResult
-arith::AddUICarryOp::fold(ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- auto carryTy = getCarry().getType();
- // addui_carry(x, 0) -> x, false
+arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto overflowTy = getOverflow().getType();
+ // addui_extended(x, 0) -> x, false
if (matchPattern(getRhs(), m_Zero())) {
- auto carryZero = APInt::getZero(1);
+ auto overflowZero = APInt::getZero(1);
Builder builder(getContext());
- auto falseValue = builder.getZeroAttr(carryTy);
+ auto falseValue = builder.getZeroAttr(overflowTy);
results.push_back(getLhs());
results.push_back(falseValue);
return success();
}
- // addui_carry(constant_a, constant_b) -> constant_sum, constant_carry
+ // addui_extended(constant_a, constant_b) -> constant_sum, constant_overflow
// Let the `constFoldBinaryOp` utility attempt to fold the sum of both
- // operands. If that succeeds, calculate the carry boolean based on the sum
+ // operands. If that succeeds, calculate the overflow bit based on the sum
// and the first (constant) operand, `lhs`. Note that we cannot simply call
- // `constFoldBinaryOp` again to calculate the carry (bit) because the
+ // `constFoldBinaryOp` again to calculate the overflow bit because the
// constructed attribute is of the same element type as both operands.
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
- Attribute carryAttr;
+ Attribute overflowAttr;
if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
- // Both arguments are scalars, calculate the scalar carry value.
+ // Both arguments are scalars, calculate the scalar overflow value.
auto sum = sumAttr.cast<IntegerAttr>();
- carryAttr = IntegerAttr::get(
- carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
+ overflowAttr = IntegerAttr::get(
+ overflowTy,
+ calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
} else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
- // Both arguments are splats, calculate the splat carry value.
+ // Both arguments are splats, calculate the splat overflow value.
auto sum = sumAttr.cast<SplatElementsAttr>();
- APInt carry = calculateCarry(sum.getSplatValue<APInt>(),
- lhs.getSplatValue<APInt>());
- carryAttr = SplatElementsAttr::get(carryTy, carry);
+ APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
+ lhs.getSplatValue<APInt>());
+ overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
} else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
- // Othwerwise calculate element-wise carry values.
+ // Otherwise calculate element-wise overflow values.
auto sum = sumAttr.cast<ElementsAttr>();
const auto numElems = static_cast<size_t>(sum.getNumElements());
- SmallVector<APInt> carryValues;
- carryValues.reserve(numElems);
+ SmallVector<APInt> overflowValues;
+ overflowValues.reserve(numElems);
auto sumIt = sum.value_begin<APInt>();
auto lhsIt = lhs.value_begin<APInt>();
for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
- carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
+ overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt));
- carryAttr = DenseElementsAttr::get(carryTy, carryValues);
+ overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues);
} else {
return failure();
}
results.push_back(sumAttr);
- results.push_back(carryAttr);
+ results.push_back(overflowAttr);
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 28134cf11932f..f10fefba87eef 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -276,11 +276,12 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
auto [rhsElem0, rhsElem1] =
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
- auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
- Value carryVal =
- rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getCarry());
+ auto lowSum =
+ rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
+ Value overflowVal =
+ rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
- Value high0 = rewriter.create<arith::AddIOp>(loc, carryVal, lhsElem1);
+ Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
Value resultVec =
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index d8e49a55c2ad7..cf207c283b0d4 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -355,24 +355,24 @@ func.func @bitcast_1d(%arg0: vector<2xf32>) {
// -----
-// CHECK-LABEL: @addui_carry_scalar
+// CHECK-LABEL: @addui_extended_scalar
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i1)
-func.func @addui_carry_scalar(%arg0: i32, %arg1: i32) -> (i32, i1) {
+func.func @addui_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i1) {
// CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (i32, i32) -> !llvm.struct<(i32, i1)>
// CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(i32, i1)>
// CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(i32, i1)>
- %sum, %carry = arith.addui_carry %arg0, %arg1 : i32, i1
+ %sum, %carry = arith.addui_extended %arg0, %arg1 : i32, i1
// CHECK-NEXT: return [[SUM]], [[CARRY]] : i32, i1
return %sum, %carry : i32, i1
}
-// CHECK-LABEL: @addui_carry_vector1d
+// CHECK-LABEL: @addui_extended_vector1d
// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi16>, [[ARG1:%.+]]: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>)
-func.func @addui_carry_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) {
+func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) {
// CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (vector<3xi16>, vector<3xi16>) -> !llvm.struct<(vector<3xi16>, vector<3xi1>)>
// CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
// CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
- %sum, %carry = arith.addui_carry %arg0, %arg1 : vector<3xi16>, vector<3xi1>
+ %sum, %carry = arith.addui_extended %arg0, %arg1 : vector<3xi16>, vector<3xi1>
// CHECK-NEXT: return [[SUM]], [[CARRY]] : vector<3xi16>, vector<3xi1>
return %sum, %carry : vector<3xi16>, vector<3xi1>
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index beb52c5f3402f..938bafa357cf3 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -73,30 +73,30 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) {
}
// Check integer add-with-carry conversions.
-// CHECK-LABEL: @int32_scalar_addui_carry
+// CHECK-LABEL: @int32_scalar_addui_extended
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
-func.func @int32_scalar_addui_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+func.func @int32_scalar_addui_extended(%lhs: i32, %rhs: i32) -> (i32, i1) {
// CHECK-NEXT: %[[IAC:.+]] = spirv.IAddCarry %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)>
// CHECK-DAG: %[[SUM:.+]] = spirv.CompositeExtract %[[IAC]][0 : i32] : !spirv.struct<(i32, i32)>
// CHECK-DAG: %[[C0:.+]] = spirv.CompositeExtract %[[IAC]][1 : i32] : !spirv.struct<(i32, i32)>
// CHECK-DAG: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK-NEXT: %[[C1:.+]] = spirv.IEqual %[[C0]], %[[ONE]] : i32
// CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
- %sum, %carry = arith.addui_carry %lhs, %rhs: i32, i1
- return %sum, %carry : i32, i1
+ %sum, %overflow = arith.addui_extended %lhs, %rhs: i32, i1
+ return %sum, %overflow : i32, i1
}
-// CHECK-LABEL: @int32_vector_addui_carry
+// CHECK-LABEL: @int32_vector_addui_extended
// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
-func.func @int32_vector_addui_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
// CHECK-NEXT: %[[IAC:.+]] = spirv.IAddCarry %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
// CHECK-DAG: %[[SUM:.+]] = spirv.CompositeExtract %[[IAC]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
// CHECK-DAG: %[[C0:.+]] = spirv.CompositeExtract %[[IAC]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
// CHECK-DAG: %[[ONE:.+]] = spirv.Constant dense<1> : vector<4xi32>
// CHECK-NEXT: %[[C1:.+]] = spirv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
// CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
- %sum, %carry = arith.addui_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
- return %sum, %carry : vector<4xi32>, vector<4xi1>
+ %sum, %overflow = arith.addui_extended %lhs, %rhs: vector<4xi32>, vector<4xi1>
+ return %sum, %overflow : vector<4xi32>, vector<4xi1>
}
// Check float unary operation conversions.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d2439a246252c..8b41aade95e0d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -640,7 +640,7 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
%zero = arith.constant 0 : i32
- %sum, %carry = arith.addui_carry %arg0, %zero: i32, i1
+ %sum, %carry = arith.addui_extended %arg0, %zero: i32, i1
return %sum, %carry : i32, i1
}
@@ -649,7 +649,7 @@ func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
%zero = arith.constant dense<0> : vector<4xi32>
- %sum, %carry = arith.addui_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
+ %sum, %carry = arith.addui_extended %arg0, %zero: vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}
@@ -658,7 +658,7 @@ func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
%zero = arith.constant 0 : i32
- %sum, %carry = arith.addui_carry %zero, %arg0: i32, i1
+ %sum, %carry = arith.addui_extended %zero, %arg0: i32, i1
return %sum, %carry : i32, i1
}
@@ -669,7 +669,7 @@ func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
func.func @addiCarryConstants() -> (i32, i1) {
%c13 = arith.constant 13 : i32
%c37 = arith.constant 37 : i32
- %sum, %carry = arith.addui_carry %c13, %c37: i32, i1
+ %sum, %carry = arith.addui_extended %c13, %c37: i32, i1
return %sum, %carry : i32, i1
}
@@ -680,7 +680,7 @@ func.func @addiCarryConstants() -> (i32, i1) {
func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
%max = arith.constant 4294967295 : i32
%c1 = arith.constant 1 : i32
- %sum, %carry = arith.addui_carry %max, %c1: i32, i1
+ %sum, %carry = arith.addui_extended %max, %c1: i32, i1
return %sum, %carry : i32, i1
}
@@ -690,7 +690,7 @@ func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
// CHECK-NEXT: return %[[c_2]], %[[true]]
func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
%max = arith.constant 4294967295 : i32
- %sum, %carry = arith.addui_carry %max, %max: i32, i1
+ %sum, %carry = arith.addui_extended %max, %max: i32, i1
return %sum, %carry : i32, i1
}
@@ -701,7 +701,7 @@ func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
%v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
%v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
- %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+ %sum, %carry = arith.addui_extended %v1, %v2 : vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}
@@ -712,7 +712,7 @@ func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
%v1 = arith.constant dense<1> : vector<4xi32>
%v2 = arith.constant dense<2> : vector<4xi32>
- %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
+ %sum, %carry = arith.addui_extended %v1, %v2 : vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 0f85e7a859386..ab47a56dce092 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -100,7 +100,7 @@ func.func @constant_vector() -> vector<3xi64> {
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
-// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_extended [[LOW0]], [[LOW1]] : i32, i1
// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : i1 to i32
// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32
// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32
@@ -118,7 +118,7 @@ func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 {
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
-// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1>
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_extended [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1>
// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : vector<4x1xi1> to vector<4x1xi32>
// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4x1xi32>
// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4x1xi32>
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 93307562c2de3..729c86514b03b 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -111,32 +111,32 @@ func.func @func_with_ops(f32) {
// -----
func.func @func_with_ops(%a: f32) {
- // expected-error at +1 {{'arith.addui_carry' op operand #0 must be signless-integer-like}}
- %r:2 = arith.addui_carry %a, %a : f32, i32
+ // expected-error at +1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+ %r:2 = arith.addui_extended %a, %a : f32, i32
return
}
// -----
func.func @func_with_ops(%a: i32) {
- // expected-error at +1 {{'arith.addui_carry' op result #1 must be bool-like}}
- %r:2 = arith.addui_carry %a, %a : i32, i32
+ // expected-error at +1 {{'arith.addui_extended' op result #1 must be bool-like}}
+ %r:2 = arith.addui_extended %a, %a : i32, i32
return
}
// -----
func.func @func_with_ops(%a: vector<8xi32>) {
- // expected-error at +1 {{'arith.addui_carry' op if an operand is non-scalar, then all results must be non-scalar}}
- %r:2 = arith.addui_carry %a, %a : vector<8xi32>, i1
+ // expected-error at +1 {{'arith.addui_extended' op if an operand is non-scalar, then all results must be non-scalar}}
+ %r:2 = arith.addui_extended %a, %a : vector<8xi32>, i1
return
}
// -----
func.func @func_with_ops(%a: vector<8xi32>) {
- // expected-error at +1 {{'arith.addui_carry' op all non-scalar operands/results must have the same shape and base type}}
- %r:2 = arith.addui_carry %a, %a : vector<8xi32>, tensor<8xi1>
+ // expected-error at +1 {{'arith.addui_extended' op all non-scalar operands/results must have the same shape and base type}}
+ %r:2 = arith.addui_extended %a, %a : vector<8xi32>, tensor<8xi1>
return
}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9d5c686d73b50..99a777d3d5f79 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -25,27 +25,27 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
return %0 : vector<[8]xi64>
}
-// CHECK-LABEL: test_addui_carry
-func.func @test_addui_carry(%arg0 : i64, %arg1 : i64) -> i64 {
- %sum, %carry = arith.addui_carry %arg0, %arg1 : i64, i1
+// CHECK-LABEL: test_addui_extended
+func.func @test_addui_extended(%arg0 : i64, %arg1 : i64) -> i64 {
+ %sum, %overflow = arith.addui_extended %arg0, %arg1 : i64, i1
return %sum : i64
}
-// CHECK-LABEL: test_addui_carry_tensor
-func.func @test_addui_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
- %sum, %carry = arith.addui_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
+// CHECK-LABEL: test_addui_extended_tensor
+func.func @test_addui_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+ %sum, %overflow = arith.addui_extended %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
return %sum : tensor<8x8xi64>
}
-// CHECK-LABEL: test_addui_carry_vector
-func.func @test_addui_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
- %0:2 = arith.addui_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
+// CHECK-LABEL: test_addui_extended_vector
+func.func @test_addui_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+ %0:2 = arith.addui_extended %arg0, %arg1 : vector<8xi64>, vector<8xi1>
return %0#0 : vector<8xi64>
}
-// CHECK-LABEL: test_addui_carry_scalable_vector
-func.func @test_addui_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
- %0:2 = arith.addui_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
+// CHECK-LABEL: test_addui_extended_scalable_vector
+func.func @test_addui_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+ %0:2 = arith.addui_extended %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
return %0#0 : vector<[8]xi64>
}
diff --git a/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir b/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir
index bc6151e1d472f..9e14fffd92b46 100644
--- a/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir
+++ b/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir
@@ -21,7 +21,7 @@ func.func @entry() {
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32>
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32>
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32>
-// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_extended [[LOW0]], [[LOW1]] : i32, i1
// CHECK: [[RES:%.+]] = llvm.bitcast {{%.+}} : vector<2xi32> to i64
// CHECK-NEXt: return [[RES]] : i64
func.func @emulate_me_please(%x : i64) -> i64 {
More information about the Mlir-commits
mailing list