[Mlir-commits] [mlir] 3afd351 - [mlir][arith] Support wide int cast emulation

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 15 08:36:58 PDT 2022

Author: Jakub Kuderski
Date: 2022-09-15T11:34:58-04:00
New Revision: 3afd351b5fd9006932857a6daf42cbd1c79c4a22

URL: https://github.com/llvm/llvm-project/commit/3afd351b5fd9006932857a6daf42cbd1c79c4a22
DIFF: https://github.com/llvm/llvm-project/commit/3afd351b5fd9006932857a6daf42cbd1c79c4a22.diff

LOG: [mlir][arith] Support wide int cast emulation

Add support for `arith.extsi`, `arith.extui`, and `arith.trunci` ops.

Tested by checking the results for all 16-bit inputs when emulating i16 with i8.

Reviewed By: antiagainst, Mogball

Differential Revision: https://reviews.llvm.org/D133612




diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index cdecf5485e95f..7716f618d9e5e 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -61,7 +61,7 @@ static Type reduceInnermostDim(VectorType type) {
 static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input,
                                  int64_t lastOffset) {
-  llvm::ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+  ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
   assert(lastOffset < shape.back() && "Offset out of bounds");
   // Scalarize the result in case of 1D vectors.
@@ -87,13 +87,45 @@ extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
           extractLastDimSlice(rewriter, loc, input, 1)};
+// Performs a vector shape cast to drop the trailing x1 (size-1) dimension,
+// e.g., `vector<3x1xi32>` becomes `vector<3xi32>`. If the `input` is a scalar,
+// this is a noop.
+//
+// The vector case requires rank >= 2: 1-D slices are expected to have been
+// scalarized already (as `extractLastDimSlice` does), so they reach this
+// function as scalars.
+static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
+                               Location loc, Value input) {
+  auto vecTy = input.getType().dyn_cast<VectorType>();
+  if (!vecTy)
+    return input;
+
+  // Shape cast to drop the last x1 dimension.
+  ArrayRef<int64_t> shape = vecTy.getShape();
+  assert(shape.size() >= 2 && "Expected vector with at least two dims");
+  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
+
+  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
+  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
+}
+// Performs a vector shape cast to append a trailing x1 (size-1) dimension,
+// e.g., `vector<3xi16>` becomes `vector<3x1xi16>`. If the `input` is a scalar,
+// this is a noop.
+static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
+                         Value input) {
+  auto vecTy = input.getType().dyn_cast<VectorType>();
+  if (!vecTy)
+    return input;
+
+  // Add a trailing x1 dim.
+  auto newShape = llvm::to_vector(vecTy.getShape());
+  newShape.push_back(1);
+  auto newTy = VectorType::get(newShape, vecTy.getElementType());
+  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
+}
 // Inserts the `source` vector slice into the `dest` vector at offset
 // `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a
 // 1D vector.
 static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
                                 Location loc, Value source, Value dest,
                                 int64_t lastOffset) {
-  llvm::ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+  ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
   assert(lastOffset < shape.back() && "Offset out of bounds");
   // Handle scalar source.
@@ -228,6 +260,104 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
+// ConvertExtSI
+
+// Emulates `arith.extsi` into a wide integer type. The result is built from
+// two halves of the narrower component type: the low half (offset 0) is the
+// sign-extended input, and the high half (offset 1) is all-ones when the low
+// half is negative and zero otherwise.
+struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    // Only results that the type converter maps to a vector of halves are
+    // handled here.
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    Type newResultComponentTy = reduceInnermostDim(newTy);
+
+    // Sign-extend the input value to determine the low half of the result.
+    // Then, check if the low half is negative, and sign-extend the comparison
+    // result to get the high half.
+    Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
+    // `createOrFold` lets a same-width extension fold away instead of
+    // materializing an op.
+    Value extended = rewriter.createOrFold<arith::ExtSIOp>(
+        loc, newResultComponentTy, newOperand);
+    Value operandZeroCst = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(newResultComponentTy));
+    Value signBit = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
+    Value signValue =
+        rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
+
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {extended, signValue});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+// ConvertExtUI
+
+// Emulates `arith.extui` into a wide integer type. The low half (offset 0) of
+// the result is the zero-extended input; the high half is always zero.
+struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    // Only results that the type converter maps to a vector of halves are
+    // handled here.
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    Type newResultComponentTy = reduceInnermostDim(newTy);
+
+    // Zero-extend the input value to determine the low half of the result.
+    // The high half is always zero, so start from an all-zero vector and
+    // insert only the low half.
+    Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
+    Value extended = rewriter.createOrFold<arith::ExtUIOp>(
+        loc, newResultComponentTy, newOperand);
+    Value zeroCst =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(newTy));
+    Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
+    rewriter.replaceOp(op, newRes);
+    return success();
+  }
+};
+// ConvertTruncI
+
+// Emulates `arith.trunci` from a wide integer type. Only the low half
+// (offset 0) of the emulated input is kept; it is further truncated when the
+// result type is narrower than the half.
+struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Check if the result type is legal for this target. Currently, we do not
+    // support truncation to types wider than supported by the target.
+    if (!getTypeConverter()->isLegal(op.getType()))
+      return rewriter.notifyMatchFailure(loc,
+                                         "unsupported truncation result type");
+
+    // Discard the high half of the input. Truncate the low half, if necessary.
+    Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
+    extracted = dropTrailingX1Dim(rewriter, loc, extracted);
+    // `createOrFold` lets a same-width truncation fold away to `extracted`.
+    Value truncated =
+        rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
+    rewriter.replaceOp(op, truncated);
+    return success();
+  }
+};
 // Pass Definition
@@ -335,6 +465,12 @@ void arith::populateWideIntEmulationPatterns(
   populateReturnOpTypeConversionPattern(patterns, typeConverter);
   // Populate `arith.*` conversion patterns.
-  patterns.add<ConvertConstant, ConvertAddI>(typeConverter,
-                                             patterns.getContext());
+  patterns.add<
+      // Misc ops.
+      ConvertConstant,
+      // Binary ops.
+      ConvertAddI,
+      // Extension and truncation ops.
+      ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
+                                                 patterns.getContext());

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index 472417681b58a..ae4c8126ae192 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -111,3 +111,97 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
     %x = arith.addi %a, %b : vector<4xi64>
     return %x : vector<4xi64>
+// CHECK-LABEL: func @extsi_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: i16) -> vector<2xi32>
+// CHECK-NEXT:    [[EXT:%.+]]  = arith.extsi [[ARG]] : i16 to i32
+// CHECK-NEXT:    [[SZ:%.+]]   = arith.constant 0 : i32
+// CHECK-NEXT:    [[SB:%.+]]   = arith.cmpi slt, [[EXT]], [[SZ]] : i32
+// CHECK-NEXT:    [[SV:%.+]]   = arith.extsi [[SB]] : i1 to i32
+// CHECK-NEXT:    [[VZ:%.+]]   = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]] = vector.insert [[SV]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK:         return [[INS1]] : vector<2xi32>
+func.func @extsi_scalar(%a : i16) -> i64 {
+    %r = arith.extsi %a : i16 to i64
+    return %r : i64
+}
+// CHECK-LABEL: func @extsi_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16>
+// CHECK-NEXT:    [[EXT:%.+]]   = arith.extsi [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32>
+// CHECK-NEXT:    [[CSTE:%.+]]  = arith.constant dense<0> : vector<3x1xi32>
+// CHECK-NEXT:    [[CMP:%.+]]   = arith.cmpi slt, [[EXT]], [[CSTE]] : vector<3x1xi32>
+// CHECK-NEXT:    [[HIGH:%.+]]  = arith.extsi [[CMP]] : vector<3x1xi1> to vector<3x1xi32>
+// CHECK-NEXT:    [[CSTZ:%.+]]  = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    [[INS0:%.+]]  = vector.insert_strided_slice [[EXT]], [[CSTZ]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]  = vector.insert_strided_slice [[HIGH]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<3x2xi32>
+func.func @extsi_vector(%a : vector<3xi16>) -> vector<3xi64> {
+    %r = arith.extsi %a : vector<3xi16> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+// CHECK-LABEL: func @extui_scalar1
+// CHECK-SAME:    ([[ARG:%.+]]: i16) -> vector<2xi32>
+// CHECK-NEXT:    [[EXT:%.+]]  = arith.extui [[ARG]] : i16 to i32
+// CHECK-NEXT:    [[VZ:%.+]]   = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK:         return [[INS0]] : vector<2xi32>
+func.func @extui_scalar1(%a : i16) -> i64 {
+    %r = arith.extui %a : i16 to i64
+    return %r : i64
+}
+// CHECK-LABEL: func @extui_scalar2
+// CHECK-SAME:    ([[ARG:%.+]]: i32) -> vector<2xi32>
+// CHECK-NEXT:    [[VZ:%.+]]   = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[ARG]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK:         return [[INS0]] : vector<2xi32>
+func.func @extui_scalar2(%a : i32) -> i64 {
+    %r = arith.extui %a : i32 to i64
+    return %r : i64
+}
+// CHECK-LABEL: func @extui_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16>
+// CHECK-NEXT:    [[EXT:%.+]]   = arith.extui [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32>
+// CHECK-NEXT:    [[CST:%.+]]   = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    [[INS0:%.+]]  = vector.insert_strided_slice [[EXT]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK:         return [[INS0]] : vector<3x2xi32>
+func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
+    %r = arith.extui %a : vector<3xi16> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+// CHECK-LABEL: func @trunci_scalar1
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> i32
+// CHECK-NEXT:    [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    return [[EXT]] : i32
+func.func @trunci_scalar1(%a : i64) -> i32 {
+    %b = arith.trunci %a : i64 to i32
+    return %b : i32
+}
+// CHECK-LABEL: func @trunci_scalar2
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> i16
+// CHECK-NEXT:    [[EXTR:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[TRNC:%.+]] = arith.trunci [[EXTR]] : i32 to i16
+// CHECK-NEXT:    return [[TRNC]] : i16
+func.func @trunci_scalar2(%a : i64) -> i16 {
+    %b = arith.trunci %a : i64 to i16
+    return %b : i16
+}
+// CHECK-LABEL: func @trunci_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xi16>
+// CHECK-NEXT:    [[EXTR:%.+]]  = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[EXTR]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[TRNC:%.+]]  = arith.trunci [[SHAPE]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT:    return [[TRNC]] : vector<3xi16>
+func.func @trunci_vector(%a : vector<3xi64>) -> vector<3xi16> {
+    %b = arith.trunci %a : vector<3xi64> to vector<3xi16>
+    return %b : vector<3xi16>
+}


