[Mlir-commits] [mlir] 3f36d2d - [mlir][arith] Simplify muli emulation with mului_extended

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 12 07:58:46 PST 2022


Author: Jakub Kuderski
Date: 2022-12-12T10:58:07-05:00
New Revision: 3f36d2d579d8b0e8824d9dd99bfa79f456858f88

URL: https://github.com/llvm/llvm-project/commit/3f36d2d579d8b0e8824d9dd99bfa79f456858f88
DIFF: https://github.com/llvm/llvm-project/commit/3f36d2d579d8b0e8824d9dd99bfa79f456858f88.diff

LOG: [mlir][arith] Simplify muli emulation with mului_extended

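Using `arith.mului_extended` makes it much simpler to emulate wide
integer multiplication. The low halves of the operands are multiplied with a
single `arith.mului_extended`. The high half of the result is then formed by
adding its high part to the two cross products (`arith.muli` low*high and
high*low). This replaces the previous long-multiplication scheme over
half-width digits, which needed 10 multiplications and 16 additions.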

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139776

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
    mlir/test/Dialect/Arith/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index f10fefba87eef..1fffb0312bb1d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -419,78 +419,26 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
-    Type newElemTy = reduceInnermostDim(newTy);
-    unsigned newBitWidth = newTy.getElementTypeBitWidth();
-    unsigned digitBitWidth = newBitWidth / 2;
-
     auto [lhsElem0, lhsElem1] =
         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
     auto [rhsElem0, rhsElem1] =
         extractLastDimHalves(rewriter, loc, adaptor.getRhs());
 
-    // Emulate multiplication by splitting each input element of type i2N into 4
-    // digits of type iN and bit width i(N/2). This is so that the intermediate
-    // multiplications and additions do not overflow. We extract these i(N/2)
-    // digits from iN vector elements by masking (low digit) and shifting right
-    // (high digit).
-    //
     // The multiplication algorithm used is the standard (long) multiplication.
-    // Multiplying two i2N integers produces (at most) a i4N result, but because
-    // the calculation of top i2N is not necessary, we omit it.
-    // In total, this implementations performs 10 intermediate multiplications
-    // and 16 additions. The number of multiplications could be decreased by
-    // switching to a more efficient algorithm like Karatsuba. This would,
-    // however, require being able to perform (intermediate) wide additions and
-    // subtractions, so it is not clear that such implementation would be more
-    // efficient.
-
-    APInt lowMaskVal(newBitWidth, 1);
-    lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1;
-    Value lowMask =
-        createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal);
-    auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
-      return rewriter.create<arith::AndIOp>(loc, newElemTy, v, lowMask);
-    };
+    // Multiplying two i2N integers produces (at most) an i4N result, but
+    // because the calculation of top i2N is not necessary, we omit it.
+    auto mulLowLow =
+        rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
+    Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
+    Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
+
+    Value resLow = mulLowLow.getLow();
+    Value resHi =
+        rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
+    resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
 
-    Value shiftVal =
-        createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth);
-    auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
-      return rewriter.create<arith::ShRUIOp>(loc, v, shiftVal);
-    };
-
-    Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0);
-    std::array<Value, 4> resultDigits = {zeroDigit, zeroDigit, zeroDigit,
-                                         zeroDigit};
-    std::array<Value, 4> lhsDigits = {
-        getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1),
-        getHighDigit(lhsElem1)};
-    std::array<Value, 4> rhsDigits = {
-        getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1),
-        getHighDigit(rhsElem1)};
-
-    for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) {
-      for (unsigned j = 0; i + j != e; ++j) {
-        Value mul =
-            rewriter.create<arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
-        Value current =
-            rewriter.createOrFold<arith::AddIOp>(loc, resultDigits[i + j], mul);
-        resultDigits[i + j] = getLowDigit(current);
-        if (i + j + 1 != e) {
-          Value carry = rewriter.createOrFold<arith::AddIOp>(
-              loc, resultDigits[i + j + 1], getHighDigit(current));
-          resultDigits[i + j + 1] = carry;
-        }
-      }
-    }
-
-    auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
-      Value highBits = rewriter.create<arith::ShLIOp>(loc, high, shiftVal);
-      return rewriter.create<arith::OrIOp>(loc, low, highBits);
-    };
-    Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]);
-    Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]);
     Value resultVec =
-        constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1});
+        constructResultVector(rewriter, loc, newTy, {resLow, resHi});
     rewriter.replaceOp(op, resultVec);
     return success();
   }

diff  --git a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
index 7a6657634b3ea..ddf5549f27e29 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir
@@ -9,8 +9,12 @@
 // CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi512>
 // CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi512>
 //
-// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1.
-// CHECK-NEXT:    {{.+}}         = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512
+// CHECK-DAG:     arith.mului_extended
+// CHECK-DAG:     arith.muli
+// CHECK-DAG:     arith.muli
+// CHECK-NEXT:    arith.addi
+// CHECK-NEXT:    arith.addi
+//
 // CHECK:         return {{%.+}} : vector<2xi512>
 func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 {
     %m = arith.muli %a, %b : i1024

diff  --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ab47a56dce092..80edc6f2ad001 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -661,44 +661,20 @@ func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c
 
 // CHECK-LABEL: func.func @muli_scalar
 // CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
-// CHECK-NEXT:    [[LOW0:%.+]]      = vector.extract [[ARG0]][0] : vector<2xi32>
-// CHECK-NEXT:    [[HIGH0:%.+]]     = vector.extract [[ARG0]][1] : vector<2xi32>
-// CHECK-NEXT:    [[LOW1:%.+]]      = vector.extract [[ARG1]][0] : vector<2xi32>
-// CHECK-NEXT:    [[HIGH1:%.+]]     = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]  = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]  = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
 //
-// CHECK-DAG:     [[MASK:%.+]]      = arith.constant 65535 : i32
-// CHECK-DAG:     [[C16:%.+]]       = arith.constant 16 : i32
+// CHECK-DAG:     [[RESLOW:%.+]], [[HI0:%.+]] = arith.mului_extended [[LOW0]], [[LOW1]] : i32
+// CHECK-DAG:     [[HI1:%.+]]                 = arith.muli [[LOW0]], [[HIGH1]] : i32
+// CHECK-DAG:     [[HI2:%.+]]                 = arith.muli [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RESHI1:%.+]]              = arith.addi [[HI0]], [[HI1]] : i32
+// CHECK-NEXT:    [[RESHI2:%.+]]              = arith.addi [[RESHI1]], [[HI2]] : i32
 //
-// CHECK:         [[LOWLOW0:%.+]]   = arith.andi [[LOW0]], [[MASK]] : i32
-// CHECK-NEXT:    [[HIGHLOW0:%.+]]  = arith.shrui [[LOW0]], [[C16]] : i32
-// CHECK-NEXT:    [[LOWHIGH0:%.+]]  = arith.andi [[HIGH0]], [[MASK]] : i32
-// CHECK-NEXT:    [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32
-// CHECK-NEXT:    [[LOWLOW1:%.+]]   = arith.andi [[LOW1]], [[MASK]] : i32
-// CHECK-NEXT:    [[HIGHLOW1:%.+]]  = arith.shrui [[LOW1]], [[C16]] : i32
-// CHECK-NEXT:    [[LOWHIGH1:%.+]]  = arith.andi [[HIGH1]], [[MASK]] : i32
-// CHECK-NEXT:    [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32
-//
-// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32
-// CHECK-DAG      {{%.+}}           = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32
-// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32
-// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32
-//
-// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32
-// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32
-// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32
-//
-// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32
-// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32
-//
-// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32
-//
-// CHECK:         [[RESHIGH0:%.+]]  = arith.shli {{%.+}}, [[C16]] : i32
-// CHECK-NEXT:    [[RES0:%.+]]      = arith.ori {{%.+}}, [[RESHIGH0]] : i32
-// CHECK-NEXT:    [[RESHIGH1:%.+]]  = arith.shli {{%.+}}, [[C16]] : i32
-// CHECK-NEXT:    [[RES1:%.+]]      = arith.ori {{%.+}}, [[RESHIGH1]] : i32
-// CHECK-NEXT:    [[VZ:%.+]]        = arith.constant dense<0> : vector<2xi32>
-// CHECK-NEXT:    [[INS0:%.+]]      = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
-// CHECK-NEXT:    [[INS1:%.+]]      = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[VZ:%.+]]   = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[RESLOW]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]] = vector.insert [[RESHI2]], [[INS0]] [1] : i32 into vector<2xi32>
 // CHECK-NEXT:    return [[INS1]] : vector<2xi32>
 func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
     %m = arith.muli %a, %b : i64
@@ -707,6 +683,11 @@ func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
 
 // CHECK-LABEL: func.func @muli_vector
 // CHECK-SAME:    ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK-DAG:     arith.mului_extended
+// CHECK-DAG:     arith.muli
+// CHECK-DAG:     arith.muli
+// CHECK-NEXT:    arith.addi
+// CHECK-NEXT:    arith.addi
 // CHECK:       return {{%.+}} : vector<3x2xi32>
 func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
     %m = arith.muli %a, %b : vector<3xi64>


        


More information about the Mlir-commits mailing list