Author: Jakub Kuderski
Date: 2022-09-09T16:49:38-04:00
New Revision: c7f64616e9e1f66059bf768324a440bfde98132c

URL: https://github.com/llvm/llvm-project/commit/c7f64616e9e1f66059bf768324a440bfde98132c
DIFF: https://github.com/llvm/llvm-project/commit/c7f64616e9e1f66059bf768324a440bfde98132c.diff

LOG: [mlir][arith] Support wide integer addition emulation

I tested this implementation exhaustively on all pairs of i16 inputs, emulating
i16 operations with i8 operations.

Reviewed By: antiagainst, Mogball

Differential Revision: https://reviews.llvm.org/D133137




diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index 2d24513e27344..2bb0de30d7a83 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -24,6 +24,10 @@ namespace mlir::arith {
 using namespace mlir;
+// Common Helper Functions
 // Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
 // Treats `value` as a 2*N bits-wide integer.
 // The bottom bits are returned in the first pair element, while the top bits in
@@ -35,6 +39,96 @@ static std::pair<APInt, APInt> getHalves(const APInt &value,
   return {std::move(low), std::move(high)};
+// Returns the type with the last (innermost) dimension reduced to x1.
+// Scalarizes 1D vector inputs to match how we extract/insert vector values,
+// e.g.:
+//   - vector<3x2xi16> --> vector<3x1xi16>
+//   - vector<2xi16>   --> i16
+static Type reduceInnermostDim(VectorType type) {
+  // A 1D vector collapses to its element type instead of vector<1xT>, mirroring
+  // the scalar result of `extractLastDimSlice` for 1D inputs.
+  if (type.getShape().size() == 1)
+    return type.getElementType();
+  // Otherwise keep all leading dimensions and shrink only the last one to 1.
+  auto newShape = to_vector(type.getShape());
+  newShape.back() = 1;
+  return VectorType::get(newShape, type.getElementType());
+// Extracts the `input` vector slice with elements at the last dimension offset
+// by `lastOffset`. Returns a value of vector type with the last dimension
+// reduced to x1 or fully scalarized, e.g.:
+//   - vector<3x2xi16> --> vector<3x1xi16>
+//   - vector<2xi16>   --> i16
+// `input` is expected to be of VectorType, and `lastOffset` must be smaller
+// than the size of its last dimension (both are asserted in debug builds).
+static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value input,
+                                 int64_t lastOffset) {
+  llvm::ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+  assert(lastOffset < shape.back() && "Offset out of bounds");
+  // Scalarize the result in case of 1D vectors.
+  if (shape.size() == 1)
+    return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
+  // For N-D vectors take every leading dimension whole (offset 0, full size)
+  // and a single element (size 1) at `lastOffset` in the last dimension, with
+  // unit strides everywhere.
+  SmallVector<int64_t> offsets(shape.size(), 0);
+  offsets.back() = lastOffset;
+  auto sizes = llvm::to_vector(shape);
+  sizes.back() = 1;
+  SmallVector<int64_t> strides(shape.size(), 1);
+  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
+                                                        sizes, strides);
+// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
+// with the first element at offset 0 and the second element at offset 1.
+// Each slice has the shape produced by `extractLastDimSlice` (last dimension
+// reduced to x1, or a scalar for 1D inputs). The returned pair is
+// {slice at offset 0, slice at offset 1}.
+static std::pair<Value, Value>
+extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
+                     Value input) {
+  return {extractLastDimSlice(rewriter, loc, input, 0),
+          extractLastDimSlice(rewriter, loc, input, 1)};
+// Inserts the `source` vector slice into the `dest` vector at offset
+// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a
+// 1D vector.
+// Returns the updated vector value; `dest` must be of VectorType and
+// `lastOffset` must be within its last dimension (asserted in debug builds).
+static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
+                                Location loc, Value source, Value dest,
+                                int64_t lastOffset) {
+  llvm::ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+  assert(lastOffset < shape.back() && "Offset out of bounds");
+  // Handle scalar source.
+  // Only integer scalars take this path; they are placed at position
+  // `lastOffset` with a plain vector.insert.
+  if (source.getType().isa<IntegerType>())
+    return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
+  // Vector source: write it at offset 0 in every leading dimension and at
+  // `lastOffset` in the last one, using unit strides.
+  SmallVector<int64_t> offsets(shape.size(), 0);
+  offsets.back() = lastOffset;
+  SmallVector<int64_t> strides(shape.size(), 1);
+  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
+                                                       offsets, strides);
+// Constructs a new vector of type `resultType` by creating a series of
+// insertions of `resultComponents`, each at the next offset of the last vector
+// dimension.
+// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
+// when `resultComponents` are `vector<...x1xT>`s, the result type is
+// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
+static Value constructResultVector(ConversionPatternRewriter &rewriter,
+                                   Location loc, VectorType resultType,
+                                   ValueRange resultComponents) {
+  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
+  // NOTE(review): the assert message below misspells "dimensions"; left as-is
+  // here because it is runtime text.
+  assert(!resultShape.empty() && "Result expected to have dimentions");
+  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
+         "Wrong number of result components");
+  // Start from an all-zeros vector and overwrite the i-th slot of the last
+  // dimension with the i-th component.
+  Value resultVec =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+  for (auto [i, component] : llvm::enumerate(resultComponents))
+    resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
+  return resultVec;
 namespace {
 // ConvertConstant
@@ -94,6 +188,45 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+// ConvertAddI
+// Emulates a wide `arith.addi` on values that the type converter has split
+// into two halves stored along the last vector dimension (offset 0: low half,
+// offset 1: high half):
+//   low  = lhs.low + rhs.low             (carry-out is captured)
+//   high = carry + lhs.high + rhs.high   (wraps; final carry-out is dropped)
+struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    // The converted type must be a vector whose last dimension holds the two
+    // halves; if conversion yields anything else, the pattern does not apply.
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+    // Type of a single half: the last dimension reduced to x1, or the element
+    // type for 1D vectors, matching what extractLastDimHalves produces.
+    Type newElemTy = reduceInnermostDim(newTy);
+    auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs);
+    auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs);
+    // Add the low halves, producing both the sum and a carry bit.
+    auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
+    // Zero-extend the carry to the half type so it can join the high-half sum.
+    Value carryVal =
+        rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getCarry());
+    // Plain (wrapping) adds for the high half: overflow out of the top half is
+    // discarded, as in the original wide addition.
+    Value high0 = rewriter.create<arith::AddIOp>(loc, carryVal, lhsElem1);
+    Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
+    // Reassemble {low, high} into a value of the converted vector type.
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
 // Pass Definition
@@ -116,12 +249,12 @@ struct EmulateWideIntPass final
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
-    target.addDynamicallyLegalOp<
-        // `func.*` ops
-        func::CallOp, func::ReturnOp,
-        // `arith.*` ops
-        arith::ConstantOp>(
-        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
+                                      vector::VectorDialect>(opLegalCallback);
     RewritePatternSet patterns(ctx);
     arith::populateWideIntEmulationPatterns(typeConverter, patterns);
@@ -201,5 +334,6 @@ void arith::populateWideIntEmulationPatterns(
   populateReturnOpTypeConversionPattern(patterns, typeConverter);
   // Populate `arith.*` conversion patterns.
-  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
+  patterns.add<ConvertConstant, ConvertAddI>(typeConverter,
+                                             patterns.getContext());

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index 3a1a6c7b2d1f9..472417681b58a 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -75,3 +75,39 @@ func.func @constant_vector() -> vector<3xi64> {
     %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
     return %c0 : vector<3xi64>
+// Scalar i64 addition: each operand becomes vector<2xi32> holding {low, high}.
+// The carry out of the low-half add is zero-extended and added into the high
+// half, and the two results are inserted back into a vector<2xi32>.
+// CHECK-LABEL: func @addi_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[CB]] : i1 to i32
+// CHECK-NEXT:    [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32
+// CHECK-NEXT:    [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[SUM_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[SUM_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+    %x = arith.addi %a, %b : i64
+    return %x : i64
+// Vector i64 addition: vector<4xi64> operands become vector<4x2xi32>. The low
+// and high halves are taken as vector<4x1xi32> strided slices, combined with
+// the same carry scheme as the scalar case, and inserted back as slices.
+// CHECK-LABEL: func @addi_vector_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1>
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[CB]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4x1xi32>
+// CHECK-NEXT:    [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK:         [[INS0:%.+]]   = vector.insert_strided_slice [[SUM_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert_strided_slice [[SUM_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<4x2xi32>
+func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+    %x = arith.addi %a, %b : vector<4xi64>
+    return %x : vector<4xi64>


