# [Mlir-commits] [mlir] db0d6e5 - [mlir][arith] Support wide integer multiplication emulation

Jakub Kuderski llvmlistbot at llvm.org
Fri Sep 16 09:04:17 PDT 2022

```Author: Jakub Kuderski
Date: 2022-09-16T12:02:25-04:00
New Revision: db0d6e567df3d34584be349347e357123246759d

URL: https://github.com/llvm/llvm-project/commit/db0d6e567df3d34584be349347e357123246759d
DIFF: https://github.com/llvm/llvm-project/commit/db0d6e567df3d34584be349347e357123246759d.diff

LOG: [mlir][arith] Support wide integer multiplication emulation

Emulate multiplication by splitting each input element of type i2N into 4
digits of type iN and bit width i(N/2). This is so that the intermediate
multiplications and additions do not overflow. We extract these i(N/2)
digits from iN vector elements by masking (low digit) and shifting right
(high digit).

The multiplication algorithm used is the standard (long) multiplication.
Multiplying two i2N integers produces (at most) a i4N result, but because
the calculation of top i2N is not necessary, we omit it.
In total, this implementation performs 10 intermediate multiplications
and 16 additions. The number of multiplications could be decreased by
switching to a more efficient algorithm like Karatsuba. This would,
however, require being able to perform (intermediate) wide additions and
subtractions, so it is not clear that such implementation would be more
efficient.

I tested this on all 16-bit input pairs, when emulating i16 with i8.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D133629

Added:
mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir
mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir

Modified:
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir

Removed:

################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index 1c4189c27c2df..dfa2779f2f38f 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -53,6 +53,35 @@ static Type reduceInnermostDim(VectorType type) {
return VectorType::get(newShape, type.getElementType());
}

+// Returns a constant of integer or vector type filled with (repeated) `value`.
+static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
+                                         Location loc, Type type,
+                                         const APInt &value) {
+  Attribute attr;
+  if (auto intTy = type.dyn_cast<IntegerType>()) {
+    attr = rewriter.getIntegerAttr(type, value);
+  } else {
+    auto vecTy = type.cast<VectorType>();
+    attr = SplatElementsAttr::get(vecTy, value);
+  }
+
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
+// Returns a constant of integer or vector type filled with (repeated) `value`.
+static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
+                                         Location loc, Type type,
+                                         int64_t value) {
+  unsigned elementBitWidth = 0;
+  if (auto intTy = type.dyn_cast<IntegerType>())
+    elementBitWidth = intTy.getWidth();
+  else
+    elementBitWidth = type.cast<VectorType>().getElementTypeBitWidth();
+
+  return createScalarOrSplatConstant(rewriter, loc, type,
+                                     APInt(elementBitWidth, value));
+}
+
// Extracts the `input` vector slice with elements at the last dimension offset
// by `lastOffset`. Returns a value of vector type with the last dimension
// reduced to x1 or fully scalarized, e.g.:
@@ -154,8 +183,7 @@ static Value constructResultVector(ConversionPatternRewriter &rewriter,
assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
"Wrong number of result components");

-  Value resultVec =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
for (auto [i, component] : llvm::enumerate(resultComponents))
resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);

@@ -232,9 +260,6 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
-
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
auto newTy = getTypeConverter()
->convertType(op.getType())
.dyn_cast_or_null<VectorType>();
@@ -243,8 +268,10 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {

Type newElemTy = reduceInnermostDim(newTy);

-    auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs);
-    auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs);
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());

auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
Value carryVal =
@@ -260,6 +287,100 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
}
};

+//===----------------------------------------------------------------------===//
+// ConvertMulI
+//===----------------------------------------------------------------------===//
+
+struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+
+    Type newElemTy = reduceInnermostDim(newTy);
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+    unsigned digitBitWidth = newBitWidth / 2;
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Emulate multiplication by splitting each input element of type i2N into 4
+    // digits stored in type iN but only i(N/2) bits wide. This is so that the
+    // intermediate multiplications and additions do not overflow. We extract
+    // these i(N/2) digits from iN vector elements by masking (low digit) and
+    // shifting right (high digit).
+    //
+    // The multiplication algorithm used is the standard (long) multiplication.
+    // Multiplying two i2N integers produces (at most) a i4N result, but because
+    // the calculation of top i2N is not necessary, we omit it.
+    // In total, this implementation performs 10 intermediate multiplications
+    // and 16 additions. The number of multiplications could be decreased by
+    // switching to a more efficient algorithm like Karatsuba. This would,
+    // however, require being able to perform (intermediate) wide additions and
+    // subtractions, so it is not clear that such implementation would be more
+    // efficient.
+
+    APInt lowMaskVal(newBitWidth, 1);
+    lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1;
+    Value lowMask =
+        createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal);
+    auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
+      return rewriter.create<arith::AndIOp>(loc, newElemTy, v, lowMask);
+    };
+
+    Value shiftVal =
+        createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth);
+    auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
+      return rewriter.create<arith::ShRUIOp>(loc, v, shiftVal);
+    };
+
+    Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0);
+    std::array<Value, 4> resultDigits = {zeroDigit, zeroDigit, zeroDigit,
+                                         zeroDigit};
+    std::array<Value, 4> lhsDigits = {
+        getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1),
+        getHighDigit(lhsElem1)};
+    std::array<Value, 4> rhsDigits = {
+        getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1),
+        getHighDigit(rhsElem1)};
+
+    for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) {
+      for (unsigned j = 0; i + j != e; ++j) {
+        Value mul =
+            rewriter.create<arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
+        Value current =
+            rewriter.createOrFold<arith::AddIOp>(loc, resultDigits[i + j], mul);
+        resultDigits[i + j] = getLowDigit(current);
+        if (i + j + 1 != e) {
+          Value carry = rewriter.createOrFold<arith::AddIOp>(
+              loc, resultDigits[i + j + 1], getHighDigit(current));
+          resultDigits[i + j + 1] = carry;
+        }
+      }
+    }
+
+    auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
+      Value highBits = rewriter.create<arith::ShLIOp>(loc, high, shiftVal);
+      return rewriter.create<arith::OrIOp>(loc, low, highBits);
+    };
+    Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]);
+    Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]);
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertExtSI
//===----------------------------------------------------------------------===//
@@ -285,8 +406,8 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
Value extended = rewriter.createOrFold<arith::ExtSIOp>(
loc, newResultComponentTy, newOperand);
-    Value operandZeroCst = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(newResultComponentTy));
+    Value operandZeroCst =
+        createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
Value signBit = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
Value signValue =
@@ -323,8 +444,7 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
Value extended = rewriter.createOrFold<arith::ExtUIOp>(
loc, newResultComponentTy, newOperand);
-    Value zeroCst = rewriter.create<arith::ConstantOp>(
-        op->getLoc(), rewriter.getZeroAttr(newTy));
+    Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
rewriter.replaceOp(op, newRes);
return success();
@@ -384,7 +504,7 @@ struct EmulateWideIntPass final
using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;

void runOnOperation() override {
-    if (!llvm::isPowerOf2_32(widestIntSupported)) {
+    if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
signalPassFailure();
return;
}
@@ -421,7 +541,8 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
unsigned widestIntSupportedByTarget)
: maxIntWidth(widestIntSupportedByTarget) {
assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
-         "Only power-of-two integers are supported");
+         "Only power-of-two integers with are supported");
+  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");

// Scalar case.
addConversion([this](IntegerType ty) -> Optional<Type> {
@@ -486,7 +607,7 @@ void arith::populateWideIntEmulationPatterns(
// Misc ops.
ConvertConstant, ConvertVectorPrint,
// Binary ops.
-      ConvertAddI,
+      ConvertAddI, ConvertMulI,
// Extension and truncation ops.
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
patterns.getContext());

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir
new file mode 100644
index 0000000000000..7a6657634b3ea
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int-very-wide.mlir
@@ -0,0 +1,18 @@
+// Check that emulation of very wide types (>64 bits) works as expected.
+
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=512" %s | FileCheck %s
+
+// CHECK-LABEL: func.func @muli_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi512>, [[ARG1:%.+]]: vector<2xi512>) -> vector<2xi512>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi512>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi512>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi512>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi512>
+//
+// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1.
+// CHECK-NEXT:    {{.+}}         = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512
+// CHECK:         return {{%.+}} : vector<2xi512>
+func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 {
+    %m = arith.muli %a, %b : i1024
+    return %m : i1024
+}

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index ae4c8126ae192..2b73f4a937cae 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -205,3 +205,57 @@ func.func @trunci_vector(%a : vector<3xi64>) -> vector<3xi16> {
%b = arith.trunci %a : vector<3xi64> to vector<3xi16>
return %b : vector<3xi16>
}
+
+// CHECK-LABEL: func.func @muli_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]      = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]     = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]      = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]     = vector.extract [[ARG1]][1] : vector<2xi32>
+//
+// CHECK-DAG:     [[MASK:%.+]]      = arith.constant 65535 : i32
+// CHECK-DAG:     [[C16:%.+]]       = arith.constant 16 : i32
+//
+// CHECK:         [[LOWLOW0:%.+]]   = arith.andi [[LOW0]], [[MASK]] : i32
+// CHECK-NEXT:    [[HIGHLOW0:%.+]]  = arith.shrui [[LOW0]], [[C16]] : i32
+// CHECK-NEXT:    [[LOWHIGH0:%.+]]  = arith.andi [[HIGH0]], [[MASK]] : i32
+// CHECK-NEXT:    [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32
+// CHECK-NEXT:    [[LOWLOW1:%.+]]   = arith.andi [[LOW1]], [[MASK]] : i32
+// CHECK-NEXT:    [[HIGHLOW1:%.+]]  = arith.shrui [[LOW1]], [[C16]] : i32
+// CHECK-NEXT:    [[LOWHIGH1:%.+]]  = arith.andi [[HIGH1]], [[MASK]] : i32
+// CHECK-NEXT:    [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32
+//
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32
+//
+// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32
+//
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32
+// CHECK-DAG:     {{%.+}}           = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32
+//
+// CHECK-DAG:     {{%.+}}           = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32
+//
+// CHECK:         [[RESHIGH0:%.+]]  = arith.shli {{%.+}}, [[C16]] : i32
+// CHECK-NEXT:    [[RES0:%.+]]      = arith.ori {{%.+}}, [[RESHIGH0]] : i32
+// CHECK-NEXT:    [[RESHIGH1:%.+]]  = arith.shli {{%.+}}, [[C16]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]      = arith.ori {{%.+}}, [[RESHIGH1]] : i32
+// CHECK-NEXT:    [[VZ:%.+]]        = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]]      = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]      = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
+    %m = arith.muli %a, %b : i64
+    return %m : i64
+}
+
+// CHECK-LABEL: func.func @muli_vector
+// CHECK-SAME:    ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:       return {{%.+}} : vector<3x2xi32>
+func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+    %m = arith.muli %a, %b : vector<3xi64>
+    return %m : vector<3xi64>
+}

diff  --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
new file mode 100644
index 0000000000000..7a56ed929d74b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
@@ -0,0 +1,97 @@
+// Check that the wide integer multiplication emulation produces the same result as wide
+// multiplication. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines --check-prefix=WIDE
+
+// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines --check-prefix=EMULATED
+
+func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+  %res = arith.muli %lhs, %rhs : i16
+  vector.print %res : i16
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+  %cst_1 = arith.constant -1 : i16
+  %cst_3 = arith.constant -3 : i16
+
+  %cst13 = arith.constant 13 : i16
+  %cst37 = arith.constant 37 : i16
+  %cst42 = arith.constant 42 : i16
+
+  %cst256 = arith.constant 256 : i16
+  %cst_i16_max = arith.constant 32767 : i16
+  %cst_i16_min = arith.constant -32768 : i16
+
+  // WIDE: 0
+  // EMULATED: ( 0, 0 )
+  func.call @check_muli(%cst0, %cst0) : (i16, i16) -> ()
+  // WIDE-NEXT: 0
+  // EMULATED-NEXT: ( 0, 0 )
+  func.call @check_muli(%cst0, %cst1) : (i16, i16) -> ()
+  // WIDE-NEXT: 1
+  // EMULATED-NEXT: ( 1, 0 )
+  func.call @check_muli(%cst1, %cst1) : (i16, i16) -> ()
+  // WIDE-NEXT: -1
+  // EMULATED-NEXT: ( -1, -1 )
+  func.call @check_muli(%cst1, %cst_1) : (i16, i16) -> ()
+  // WIDE-NEXT: 1
+  // EMULATED-NEXT: ( 1, 0 )
+  func.call @check_muli(%cst_1, %cst_1) : (i16, i16) -> ()
+  // WIDE-NEXT: -3
+  // EMULATED-NEXT: ( -3, -1 )
+  func.call @check_muli(%cst1, %cst_3) : (i16, i16) -> ()
+
+  // WIDE-NEXT: 169
+  // EMULATED-NEXT: ( -87, 0 )
+  func.call @check_muli(%cst13, %cst13) : (i16, i16) -> ()
+  // WIDE-NEXT: 481
+  // EMULATED-NEXT: ( -31, 1 )
+  func.call @check_muli(%cst13, %cst37) : (i16, i16) -> ()
+  // WIDE-NEXT: 1554
+  // EMULATED-NEXT: ( 18, 6 )
+  func.call @check_muli(%cst37, %cst42) : (i16, i16) -> ()
+
+  // WIDE-NEXT: -256
+  // EMULATED-NEXT: ( 0, -1 )
+  func.call @check_muli(%cst_1, %cst256) : (i16, i16) -> ()
+  // WIDE-NEXT: 3328
+  // EMULATED-NEXT: ( 0, 13 )
+  func.call @check_muli(%cst256, %cst13) : (i16, i16) -> ()
+  // WIDE-NEXT: 9472
+  // EMULATED-NEXT: ( 0, 37 )
+  func.call @check_muli(%cst256, %cst37) : (i16, i16) -> ()
+  // WIDE-NEXT: -768
+  // EMULATED-NEXT: ( 0, -3 )
+  func.call @check_muli(%cst256, %cst_3) : (i16, i16) -> ()
+
+  // WIDE-NEXT: 32755
+  // EMULATED-NEXT: ( -13, 127 )
+  func.call @check_muli(%cst13, %cst_i16_max) : (i16, i16) -> ()
+  // WIDE-NEXT: -32768
+  // EMULATED-NEXT: ( 0, -128 )
+  func.call @check_muli(%cst_i16_min, %cst37) : (i16, i16) -> ()
+
+  // WIDE-NEXT: 1
+  // EMULATED-NEXT: ( 1, 0 )
+  func.call @check_muli(%cst_i16_max, %cst_i16_max) : (i16, i16) -> ()
+  // WIDE-NEXT: -32768
+  // EMULATED-NEXT: ( 0, -128 )
+  func.call @check_muli(%cst_i16_min, %cst13) : (i16, i16) -> ()
+  // WIDE-NEXT: 0
+  // EMULATED-NEXT: ( 0, 0 )
+  func.call @check_muli(%cst_i16_min, %cst_i16_min) : (i16, i16) -> ()
+
+  return
+}

```

More information about the Mlir-commits mailing list