[Mlir-commits] [mlir] 95c3e9c - [mlir][arith] Support wide int shrui emulation
Jakub Kuderski
llvmlistbot at llvm.org
Fri Sep 16 09:10:40 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-16T12:09:33-04:00
New Revision: 95c3e9c222121bca4b12aeaede03824edc8a33db
URL: https://github.com/llvm/llvm-project/commit/95c3e9c222121bca4b12aeaede03824edc8a33db
DIFF: https://github.com/llvm/llvm-project/commit/95c3e9c222121bca4b12aeaede03824edc8a33db.diff
LOG: [mlir][arith] Support wide int shrui emulation
Tested by checking all 16-bit LHS and all valid RHS when emulating i16 with i8 operations.
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D133722
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index dfa2779f2f38..e0f000af2809 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -28,10 +28,10 @@ using namespace mlir;
// Common Helper Functions
//===----------------------------------------------------------------------===//
-// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
-// Treats `value` as a 2*N bits-wide integer.
-// The bottom bits are returned in the first pair element, while the top bits in
-// the second one.
+/// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
+/// Treats `value` as a 2*N bits-wide integer.
+/// The bottom bits are returned in the first pair element, while the top bits
+/// in the second one.
static std::pair<APInt, APInt> getHalves(const APInt &value,
unsigned newBitWidth) {
APInt low = value.extractBits(newBitWidth, 0);
@@ -39,11 +39,11 @@ static std::pair<APInt, APInt> getHalves(const APInt &value,
return {std::move(low), std::move(high)};
}
-// Returns the type with the last (innermost) dimention reduced to x1.
-// Scalarizes 1D vector inputs to match how we extract/insert vector values,
-// e.g.:
-// - vector<3x2xi16> --> vector<3x1xi16>
-// - vector<2xi16> --> i16
+/// Returns the type with the last (innermost) dimension reduced to x1.
+/// Scalarizes 1D vector inputs to match how we extract/insert vector values,
+/// e.g.:
+/// - vector<3x2xi16> --> vector<3x1xi16>
+/// - vector<2xi16> --> i16
static Type reduceInnermostDim(VectorType type) {
if (type.getShape().size() == 1)
return type.getElementType();
@@ -53,7 +53,7 @@ static Type reduceInnermostDim(VectorType type) {
return VectorType::get(newShape, type.getElementType());
}
-// Returns a constant of integer of vector type filled with (repeated) `value`.
+/// Returns a constant of integer or vector type filled with (repeated) `value`.
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
@@ -68,7 +68,7 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
-// Returns a constant of integer of vector type filled with (repeated) `value`.
+/// Returns a constant of integer or vector type filled with (repeated) `value`.
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
int64_t value) {
@@ -82,11 +82,11 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
APInt(elementBitWidth, value));
}
-// Extracts the `input` vector slice with elements at the last dimension offset
-// by `lastOffset`. Returns a value of vector type with the last dimension
-// reduced to x1 or fully scalarized, e.g.:
-// - vector<3x2xi16> --> vector<3x1xi16>
-// - vector<2xi16> --> i16
+/// Extracts the `input` vector slice with elements at the last dimension offset
+/// by `lastOffset`. Returns a value of vector type with the last dimension
+/// reduced to x1 or fully scalarized, e.g.:
+/// - vector<3x2xi16> --> vector<3x1xi16>
+/// - vector<2xi16> --> i16
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t lastOffset) {
@@ -107,8 +107,8 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
sizes, strides);
}
-// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
-// with the first element at offset 0 and the second element at offset 1.
+/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
+/// with the first element at offset 0 and the second element at offset 1.
static std::pair<Value, Value>
extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
Value input) {
@@ -133,8 +133,8 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
}
-// Performs a vector shape cast to append an x1 dimension. If the
-// `input` is a scalar, this is a noop.
+/// Performs a vector shape cast to append an x1 dimension. If the
+/// `input` is a scalar, this is a noop.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
Value input) {
auto vecTy = input.getType().dyn_cast<VectorType>();
@@ -148,9 +148,9 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
}
-// Inserts the `source` vector slice into the `dest` vector at offset
-// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a
-// 1D vector.
+/// Inserts the `source` vector slice into the `dest` vector at offset
+/// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
+/// a 1D vector.
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value source, Value dest,
int64_t lastOffset) {
@@ -168,12 +168,12 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
offsets, strides);
}
-// Constructs a new vector of type `resultType` by creating a series of
-// insertions of `resultComponents`, each at the next offset of the last vector
-// dimension.
-// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
-// when `resultComponents` are `vector<...x1xT>`s, the result type is
-// `vector<...xNxT>`, where `N` is the number of `resultComponenets`.
+/// Constructs a new vector of type `resultType` by creating a series of
+/// insertions of `resultComponents`, each at the next offset of the last vector
+/// dimension.
+/// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
+/// when `resultComponents` are `vector<...x1xT>`s, the result type is
+/// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
static Value constructResultVector(ConversionPatternRewriter &rewriter,
Location loc, VectorType resultType,
ValueRange resultComponents) {
@@ -451,6 +451,90 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertShRUI
+//===----------------------------------------------------------------------===//
+
+struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
+ using OpConversionPattern::OpConversionPattern;
+ /// Rewrites a wide `arith.shrui` into ops on the halves of each operand.
+ LogicalResult
+ matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ Type oldTy = op.getType();
+ auto newTy = getTypeConverter()->convertType(oldTy).cast<VectorType>();
+ Type newOperandTy = reduceInnermostDim(newTy);
+ unsigned newBitWidth = newTy.getElementTypeBitWidth();
+ // LHS is split into halves; only the low half of RHS is extracted.
+ auto [lhsElem0, lhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+ Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+ // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
+ // high halves of the results separately:
+ // 1. low := a or b or c, where:
+ // a) Bits from LHS.low, shifted by the RHS.
+ // b) Bits from LHS.high, shifted left. These matter when
+ // RHS < newBitWidth, e.g.:
+ // [hhhh][0000] shrui 3 --> [000h][hhh0]
+ // ^
+ // |
+ // [hhhh] shli (4 - 3)
+ // c) Bits from LHS.high, shifted right. These come into play when
+ // RHS >= newBitWidth, e.g.:
+ // [hhhh][0000] shrui 7 --> [0000][000h]
+ // ^
+ // |
+ // [hhhh] shrui (7 - 4)
+ //
+ // 2. high := LHS.high shrui RHS
+ //
+ // Because shifts by values >= newBitWidth are undefined, we ignore the high
+ // half of RHS, and introduce 'bounds checks' to account for
+ // RHS.low >= newBitWidth. NOTE(review): when RHS.low == 0, step 1b shifts
+ // LHS.high left by newBitWidth (undefined per above); confirm intent.
+ // TODO: Explore possible optimizations.
+ Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
+ Value elemBitWidth =
+ createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
+
+ Value illegalElemShift = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+ // Steps 1a and 2: both are zeroed when RHS.low >= newBitWidth.
+ Value shiftedElem0 =
+ rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
+ Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
+ zeroCst, shiftedElem0);
+ Value shiftedElem1 =
+ rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
+ Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
+ zeroCst, shiftedElem1);
+ // Steps 1b and 1c: the select below keeps 1c iff RHS.low >= newBitWidth.
+ Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
+ loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value leftShiftAmount =
+ rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ Value shiftedLeft =
+ rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
+ Value overshotShiftAmount =
+ rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ Value shiftedRight =
+ rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
+
+ Value resElem0High = rewriter.create<arith::SelectOp>(
+ loc, illegalElemShift, shiftedRight, shiftedLeft);
+ Value resElem0 =
+ rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
+
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertTruncI
//===----------------------------------------------------------------------===//
@@ -607,7 +691,7 @@ void arith::populateWideIntEmulationPatterns(
// Misc ops.
ConvertConstant, ConvertVectorPrint,
// Binary ops.
- ConvertAddI, ConvertMulI,
+ ConvertAddI, ConvertMulI, ConvertShRUI,
// Extension and truncation ops.
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
patterns.getContext());
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index 2b73f4a937ca..6445ea44cada 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -259,3 +259,57 @@ func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
%m = arith.muli %a, %b : vector<3xi64>
return %m : vector<3xi64>
}
+
+// CHECK-LABEL: func.func @shrui_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[CST32:%.+]] = arith.constant 32 : i32
+// CHECK-DAG: [[OOB:%.+]] = arith.cmpi uge, [[LOW1]], [[CST32]] : i32
+// CHECK-DAG: [[SHLOW0:%.+]] = arith.shrui [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES0LOW:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLOW0]] : i32
+// CHECK-NEXT: [[SHRHIGH0:%.+]] = arith.shrui [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.select [[OOB]], [[CST0]], [[SHRHIGH0]] : i32
+// CHECK-NEXT: [[SHAMT:%.+]] = arith.select [[OOB]], [[CST32]], [[LOW1]] : i32
+// CHECK-NEXT: [[LSHAMT:%.+]] = arith.subi [[CST32]], [[SHAMT]] : i32
+// CHECK-NEXT: [[SHLHIGH0:%.+]] = arith.shli [[HIGH0]], [[LSHAMT]] : i32
+// CHECK-NEXT: [[RSHAMT:%.+]] = arith.subi [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT: [[OVERSHRHIGH0:%.+]] = arith.shrui [[HIGH0]], [[RSHAMT]] : i32
+// CHECK-NEXT: [[RES0HIGH:%.+]] = arith.select [[OOB]], [[OVERSHRHIGH0]], [[SHLHIGH0]] : i32
+// CHECK-NEXT: [[RES0:%.+]] = arith.ori [[RES0LOW]], [[RES0HIGH]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @shrui_scalar(%a : i64, %b : i64) -> i64 {
+ %c = arith.shrui %a, %b : i64 // Lowered to ops on i32 halves, see CHECKs.
+ return %c : i64
+}
+
+// CHECK-LABEL: func.func @shrui_scalar_cst_2
+// CHECK-SAME: ({{%.+}}: vector<2xi32>) -> vector<2xi32>
+// CHECK: return {{%.+}} : vector<2xi32>
+func.func @shrui_scalar_cst_2(%a : i64) -> i64 {
+ %b = arith.constant 2 : i64 // Shift amount < 32, the i32 half width.
+ %c = arith.shrui %a, %b : i64
+ return %c : i64
+}
+
+// CHECK-LABEL: func.func @shrui_scalar_cst_36
+// CHECK-SAME: ({{%.+}}: vector<2xi32>) -> vector<2xi32>
+// CHECK: return {{%.+}} : vector<2xi32>
+func.func @shrui_scalar_cst_36(%a : i64) -> i64 {
+ %b = arith.constant 36 : i64 // Shift amount >= 32, the i32 half width.
+ %c = arith.shrui %a, %b : i64
+ return %c : i64
+}
+
+// CHECK-LABEL: func.func @shrui_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %c = arith.shrui %a, %b : vector<3xi64> // Each i64 lane maps to a 2xi32 row.
+ return %c : vector<3xi64>
+}
More information about the Mlir-commits
mailing list