[Mlir-commits] [mlir] c7f6461 - [mlir][arith] Support wide integer addition emulation
Jakub Kuderski
llvmlistbot at llvm.org
Fri Sep 9 13:50:39 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-09T16:49:38-04:00
New Revision: c7f64616e9e1f66059bf768324a440bfde98132c
URL: https://github.com/llvm/llvm-project/commit/c7f64616e9e1f66059bf768324a440bfde98132c
DIFF: https://github.com/llvm/llvm-project/commit/c7f64616e9e1f66059bf768324a440bfde98132c.diff
LOG: [mlir][arith] Support wide integer addition emulation
I tested this implementation exhaustively over all i16 input pairs when
emulating i16 operations with i8 operations.
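
For reference, below is a minimal standalone C++ sketch (not part of this
patch; the function name is made up for illustration) of the arithmetic the
new ConvertAddI pattern emits, assuming a 64-bit integer is represented as a
pair of 32-bit (low, high) halves:

  #include <cstdint>
  #include <utility>

  // Emulated 64-bit addition over (low, high) 32-bit halves. The low halves
  // are added first; the carry out of that addition is then folded into the
  // sum of the high halves, mirroring arith.addui_carry + arith.extui
  // followed by two arith.addi ops.
  static std::pair<uint32_t, uint32_t>
  emulatedAddI64(std::pair<uint32_t, uint32_t> lhs,
                 std::pair<uint32_t, uint32_t> rhs) {
    uint32_t lowSum = lhs.first + rhs.first;        // wrapping low-half add
    uint32_t carry = lowSum < lhs.first ? 1u : 0u;  // carry bit, zero-extended
    uint32_t highSum = carry + lhs.second + rhs.second;
    return {lowSum, highSum};
  }

The carry here corresponds to the i1 result of arith.addui_carry, which the
pattern zero-extends to the narrow element type before adding it into the
high half.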
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D133137
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index 2d24513e27344..2bb0de30d7a83 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -24,6 +24,10 @@ namespace mlir::arith {
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Common Helper Functions
+//===----------------------------------------------------------------------===//
+
// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
// Treats `value` as a 2*N bits-wide integer.
// The bottom bits are returned in the first pair element, while the top bits in
@@ -35,6 +39,96 @@ static std::pair<APInt, APInt> getHalves(const APInt &value,
return {std::move(low), std::move(high)};
}
+// Returns the type with the last (innermost) dimension reduced to x1.
+// Scalarizes 1D vector inputs to match how we extract/insert vector values,
+// e.g.:
+// - vector<3x2xi16> --> vector<3x1xi16>
+// - vector<2xi16> --> i16
+static Type reduceInnermostDim(VectorType type) {
+ if (type.getShape().size() == 1)
+ return type.getElementType();
+
+ auto newShape = to_vector(type.getShape());
+ newShape.back() = 1;
+ return VectorType::get(newShape, type.getElementType());
+}
+
+// Extracts the `input` vector slice with elements at the last dimension offset
+// by `lastOffset`. Returns a value of vector type with the last dimension
+// reduced to x1 or fully scalarized, e.g.:
+// - vector<3x2xi16> --> vector<3x1xi16>
+// - vector<2xi16> --> i16
+static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ int64_t lastOffset) {
+ llvm::ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+ assert(lastOffset < shape.back() && "Offset out of bounds");
+
+ // Scalarize the result in case of 1D vectors.
+ if (shape.size() == 1)
+ return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
+
+ SmallVector<int64_t> offsets(shape.size(), 0);
+ offsets.back() = lastOffset;
+ auto sizes = llvm::to_vector(shape);
+ sizes.back() = 1;
+ SmallVector<int64_t> strides(shape.size(), 1);
+
+ return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
+ sizes, strides);
+}
+
+// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
+// with the first element at offset 0 and the second element at offset 1.
+static std::pair<Value, Value>
+extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
+ Value input) {
+ return {extractLastDimSlice(rewriter, loc, input, 0),
+ extractLastDimSlice(rewriter, loc, input, 1)};
+}
+
+// Inserts the `source` vector slice into the `dest` vector at offset
+// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a
+// 1D vector.
+static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
+ Location loc, Value source, Value dest,
+ int64_t lastOffset) {
+ llvm::ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+ assert(lastOffset < shape.back() && "Offset out of bounds");
+
+ // Handle scalar source.
+ if (source.getType().isa<IntegerType>())
+ return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
+
+ SmallVector<int64_t> offsets(shape.size(), 0);
+ offsets.back() = lastOffset;
+ SmallVector<int64_t> strides(shape.size(), 1);
+ return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
+ offsets, strides);
+}
+
+// Constructs a new vector of type `resultType` by creating a series of
+// insertions of `resultComponents`, each at the next offset of the last vector
+// dimension.
+// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
+// when `resultComponents` are `vector<...x1xT>`s, the result type is
+// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
+static Value constructResultVector(ConversionPatternRewriter &rewriter,
+ Location loc, VectorType resultType,
+ ValueRange resultComponents) {
+ llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
+ assert(!resultShape.empty() && "Result expected to have dimentions");
+ assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
+ "Wrong number of result components");
+
+ Value resultVec =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+ for (auto [i, component] : llvm::enumerate(resultComponents))
+ resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
+
+ return resultVec;
+}
+
namespace {
//===----------------------------------------------------------------------===//
// ConvertConstant
@@ -94,6 +188,45 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertAddI
+//===----------------------------------------------------------------------===//
+
+struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ Value lhs = adaptor.getLhs();
+ Value rhs = adaptor.getRhs();
+ auto newTy = getTypeConverter()
+ ->convertType(op.getType())
+ .dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+
+ Type newElemTy = reduceInnermostDim(newTy);
+
+ auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs);
+ auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs);
+
+ auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
+ Value carryVal =
+ rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getCarry());
+
+ Value high0 = rewriter.create<arith::AddIOp>(loc, carryVal, lhsElem1);
+ Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
+
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -116,12 +249,12 @@ struct EmulateWideIntPass final
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
});
- target.addDynamicallyLegalOp<
- // `func.*` ops
- func::CallOp, func::ReturnOp,
- // `arith.*` ops
- arith::ConstantOp>(
- [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+ auto opLegalCallback = [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op);
+ };
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+ target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
+ vector::VectorDialect>(opLegalCallback);
RewritePatternSet patterns(ctx);
arith::populateWideIntEmulationPatterns(typeConverter, patterns);
@@ -201,5 +334,6 @@ void arith::populateWideIntEmulationPatterns(
populateReturnOpTypeConversionPattern(patterns, typeConverter);
// Populate `arith.*` conversion patterns.
- patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
+ patterns.add<ConvertConstant, ConvertAddI>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index 3a1a6c7b2d1f9..472417681b58a 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -75,3 +75,39 @@ func.func @constant_vector() -> vector<3xi64> {
%c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
return %c0 : vector<3xi64>
}
+
+// CHECK-LABEL: func @addi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : i1 to i32
+// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32
+// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[SUM_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUM_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.addi %a, %b : i64
+ return %x : i64
+}
+
+// CHECK-LABEL: func @addi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1>
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4x1xi32>
+// CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUM_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUM_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<4x2xi32>
+func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+ %x = arith.addi %a, %b : vector<4xi64>
+ return %x : vector<4xi64>
+}