[Mlir-commits] [mlir] 3f36d2d - [mlir][arith] Simplify muli emulation with mului_extended
Jakub Kuderski
llvmlistbot at llvm.org
Mon Dec 12 07:58:46 PST 2022
Author: Jakub Kuderski
Date: 2022-12-12T10:58:07-05:00
New Revision: 3f36d2d579d8b0e8824d9dd99bfa79f456858f88
URL: https://github.com/llvm/llvm-project/commit/3f36d2d579d8b0e8824d9dd99bfa79f456858f88
DIFF: https://github.com/llvm/llvm-project/commit/3f36d2d579d8b0e8824d9dd99bfa79f456858f88.diff
LOG: [mlir][arith] Simplify muli emulation with mului_extended
Using `arith.mului_extended` makes it much simpler to emulate wide
integer multiplication.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139776
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
mlir/test/Dialect/Arith/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index f10fefba87eef..1fffb0312bb1d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -419,78 +419,26 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", op.getType()));
- Type newElemTy = reduceInnermostDim(newTy);
- unsigned newBitWidth = newTy.getElementTypeBitWidth();
- unsigned digitBitWidth = newBitWidth / 2;
-
auto [lhsElem0, lhsElem1] =
extractLastDimHalves(rewriter, loc, adaptor.getLhs());
auto [rhsElem0, rhsElem1] =
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
- // Emulate multiplication by splitting each input element of type i2N into 4
- // digits of type iN and bit width i(N/2). This is so that the intermediate
- // multiplications and additions do not overflow. We extract these i(N/2)
- // digits from iN vector elements by masking (low digit) and shifting right
- // (high digit).
- //
// The multiplication algorithm used is the standard (long) multiplication.
- // Multiplying two i2N integers produces (at most) a i4N result, but because
- // the calculation of top i2N is not necessary, we omit it.
- // In total, this implementations performs 10 intermediate multiplications
- // and 16 additions. The number of multiplications could be decreased by
- // switching to a more efficient algorithm like Karatsuba. This would,
- // however, require being able to perform (intermediate) wide additions and
- // subtractions, so it is not clear that such implementation would be more
- // efficient.
-
- APInt lowMaskVal(newBitWidth, 1);
- lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1;
- Value lowMask =
- createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal);
- auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
- return rewriter.create<arith::AndIOp>(loc, newElemTy, v, lowMask);
- };
+ // Multiplying two i2N integers produces (at most) an i4N result, but
+ // because the calculation of top i2N is not necessary, we omit it.
+ auto mulLowLow =
+ rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
+ Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
+ Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
+
+ Value resLow = mulLowLow.getLow();
+ Value resHi =
+ rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
+ resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
- Value shiftVal =
- createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth);
- auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
- return rewriter.create<arith::ShRUIOp>(loc, v, shiftVal);
- };
-
- Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0);
- std::array<Value, 4> resultDigits = {zeroDigit, zeroDigit, zeroDigit,
- zeroDigit};
- std::array<Value, 4> lhsDigits = {
- getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1),
- getHighDigit(lhsElem1)};
- std::array<Value, 4> rhsDigits = {
- getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1),
- getHighDigit(rhsElem1)};
-
- for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) {
- for (unsigned j = 0; i + j != e; ++j) {
- Value mul =
- rewriter.create<arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
- Value current =
- rewriter.createOrFold<arith::AddIOp>(loc, resultDigits[i + j], mul);
- resultDigits[i + j] = getLowDigit(current);
- if (i + j + 1 != e) {
- Value carry = rewriter.createOrFold<arith::AddIOp>(
- loc, resultDigits[i + j + 1], getHighDigit(current));
- resultDigits[i + j + 1] = carry;
- }
- }
- }
-
- auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
- Value highBits = rewriter.create<arith::ShLIOp>(loc, high, shiftVal);
- return rewriter.create<arith::OrIOp>(loc, low, highBits);
- };
- Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]);
- Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]);
Value resultVec =
- constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1});
+ constructResultVector(rewriter, loc, newTy, {resLow, resHi});
rewriter.replaceOp(op, resultVec);
return success();
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
index 7a6657634b3ea..ddf5549f27e29 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
@@ -9,8 +9,12 @@
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi512>
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi512>
//
-// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1.
-// CHECK-NEXT: {{.+}} = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512
+// CHECK-DAG: arith.mului_extended
+// CHECK-DAG: arith.muli
+// CHECK-DAG: arith.muli
+// CHECK-NEXT: arith.addi
+// CHECK-NEXT: arith.addi
+//
// CHECK: return {{%.+}} : vector<2xi512>
func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 {
%m = arith.muli %a, %b : i1024
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ab47a56dce092..80edc6f2ad001 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -661,44 +661,20 @@ func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c
// CHECK-LABEL: func.func @muli_scalar
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
-// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
-// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
-// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
-// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
//
-// CHECK-DAG: [[MASK:%.+]] = arith.constant 65535 : i32
-// CHECK-DAG: [[C16:%.+]] = arith.constant 16 : i32
+// CHECK-DAG: [[RESLOW:%.+]], [[HI0:%.+]] = arith.mului_extended [[LOW0]], [[LOW1]] : i32
+// CHECK-DAG: [[HI1:%.+]] = arith.muli [[LOW0]], [[HIGH1]] : i32
+// CHECK-DAG: [[HI2:%.+]] = arith.muli [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RESHI1:%.+]] = arith.addi [[HI0]], [[HI1]] : i32
+// CHECK-NEXT: [[RESHI2:%.+]] = arith.addi [[RESHI1]], [[HI2]] : i32
//
-// CHECK: [[LOWLOW0:%.+]] = arith.andi [[LOW0]], [[MASK]] : i32
-// CHECK-NEXT: [[HIGHLOW0:%.+]] = arith.shrui [[LOW0]], [[C16]] : i32
-// CHECK-NEXT: [[LOWHIGH0:%.+]] = arith.andi [[HIGH0]], [[MASK]] : i32
-// CHECK-NEXT: [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32
-// CHECK-NEXT: [[LOWLOW1:%.+]] = arith.andi [[LOW1]], [[MASK]] : i32
-// CHECK-NEXT: [[HIGHLOW1:%.+]] = arith.shrui [[LOW1]], [[C16]] : i32
-// CHECK-NEXT: [[LOWHIGH1:%.+]] = arith.andi [[HIGH1]], [[MASK]] : i32
-// CHECK-NEXT: [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32
-//
-// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32
-// CHECK-DAG {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32
-// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32
-// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32
-//
-// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32
-// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32
-// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32
-//
-// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32
-// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32
-//
-// CHECK-DAG: {{%.+}} = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32
-//
-// CHECK: [[RESHIGH0:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
-// CHECK-NEXT: [[RES0:%.+]] = arith.ori {{%.+}}, [[RESHIGH0]] : i32
-// CHECK-NEXT: [[RESHIGH1:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
-// CHECK-NEXT: [[RES1:%.+]] = arith.ori {{%.+}}, [[RESHIGH1]] : i32
-// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
-// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
-// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RESLOW]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RESHI2]], [[INS0]] [1] : i32 into vector<2xi32>
// CHECK-NEXT: return [[INS1]] : vector<2xi32>
func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
%m = arith.muli %a, %b : i64
@@ -707,6 +683,11 @@ func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
// CHECK-LABEL: func.func @muli_vector
// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK-DAG: arith.mului_extended
+// CHECK-DAG: arith.muli
+// CHECK-DAG: arith.muli
+// CHECK-NEXT: arith.addi
+// CHECK-NEXT: arith.addi
// CHECK: return {{%.+}} : vector<3x2xi32>
func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
%m = arith.muli %a, %b : vector<3xi64>
More information about the Mlir-commits
mailing list