[Mlir-commits] [mlir] 3afd351 - [mlir][arith] Support wide int cast emulation
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 15 08:36:58 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-15T11:34:58-04:00
New Revision: 3afd351b5fd9006932857a6daf42cbd1c79c4a22
URL: https://github.com/llvm/llvm-project/commit/3afd351b5fd9006932857a6daf42cbd1c79c4a22
DIFF: https://github.com/llvm/llvm-project/commit/3afd351b5fd9006932857a6daf42cbd1c79c4a22.diff
LOG: [mlir][arith] Support wide int cast emulation
Add support for `arith.extsi`, `arith.extui`, and `arith.trunci` ops.
Tested by checking the results for all 16-bit inputs when emulating i16 with i8.
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D133612
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index cdecf5485e95f..7716f618d9e5e 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -61,7 +61,7 @@ static Type reduceInnermostDim(VectorType type) {
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t lastOffset) {
- llvm::ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Scalarize the result in case of 1D vectors.
@@ -87,13 +87,45 @@ extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
extractLastDimSlice(rewriter, loc, input, 1)};
}
+// Performs a vector shape cast to drop the trailing x1 dimension. If the
+// `input` is a scalar, this is a noop.
+static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
+ auto vecTy = input.getType().dyn_cast<VectorType>();
+ if (!vecTy)
+ return input;
+
+ // Shape cast to drop the last x1 dimension.
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ assert(shape.size() >= 2 && "Expected vector with at least two dims");
+ assert(shape.back() == 1 && "Expected the last vector dim to be x1");
+
+ auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
+ return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
+}
+
+// Performs a vector shape cast to append an x1 dimension. If the
+// `input` is a scalar, this is a noop.
+static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
+ Value input) {
+ auto vecTy = input.getType().dyn_cast<VectorType>();
+ if (!vecTy)
+ return input;
+
+ // Add a trailing x1 dim.
+ auto newShape = llvm::to_vector(vecTy.getShape());
+ newShape.push_back(1);
+ auto newTy = VectorType::get(newShape, vecTy.getElementType());
+ return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
+}
+
// Inserts the `source` vector slice into the `dest` vector at offset
// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is a
// 1D vector.
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value source, Value dest,
int64_t lastOffset) {
- llvm::ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Handle scalar source.
@@ -228,6 +260,104 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertExtSI
+//===----------------------------------------------------------------------===//
+
+struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto newTy = getTypeConverter()
+ ->convertType(op.getType())
+ .dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+ Type newResultComponentTy = reduceInnermostDim(newTy);
+
+ // Sign-extend the input value to determine the low half of the result.
+ // Then, check if the low half is negative, and sign-extend the comparison
+ // result to get the high half.
+ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
+ Value extended = rewriter.createOrFold<arith::ExtSIOp>(
+ loc, newResultComponentTy, newOperand);
+ Value operandZeroCst = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(newResultComponentTy));
+ Value signBit = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
+ Value signValue =
+ rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
+
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {extended, signValue});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertExtUI
+//===----------------------------------------------------------------------===//
+
+struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto newTy = getTypeConverter()
+ ->convertType(op.getType())
+ .dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+ Type newResultComponentTy = reduceInnermostDim(newTy);
+
+ // Zero-extend the input value to determine the low half of the result.
+ // The high half is always zero.
+ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
+ Value extended = rewriter.createOrFold<arith::ExtUIOp>(
+ loc, newResultComponentTy, newOperand);
+ Value zeroCst = rewriter.create<arith::ConstantOp>(
+ op->getLoc(), rewriter.getZeroAttr(newTy));
+ Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
+ rewriter.replaceOp(op, newRes);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertTruncI
+//===----------------------------------------------------------------------===//
+
+struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Check if the result type is legal for this target. Currently, we do not
+ // support truncation to types wider than supported by the target.
+ if (!getTypeConverter()->isLegal(op.getType()))
+ return rewriter.notifyMatchFailure(loc,
+ "unsupported truncation result type");
+
+ // Discard the high half of the input. Truncate the low half, if necessary.
+ Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
+ extracted = dropTrailingX1Dim(rewriter, loc, extracted);
+ Value truncated =
+ rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
+ rewriter.replaceOp(op, truncated);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -335,6 +465,12 @@ void arith::populateWideIntEmulationPatterns(
populateReturnOpTypeConversionPattern(patterns, typeConverter);
// Populate `arith.*` conversion patterns.
- patterns.add<ConvertConstant, ConvertAddI>(typeConverter,
- patterns.getContext());
+ patterns.add<
+ // Misc ops.
+ ConvertConstant,
+ // Binary ops.
+ ConvertAddI,
+ // Extension and truncation ops.
+ ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index 472417681b58a..ae4c8126ae192 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -111,3 +111,97 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
%x = arith.addi %a, %b : vector<4xi64>
return %x : vector<4xi64>
}
+
+// CHECK-LABEL: func @extsi_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i16) -> vector<2xi32>
+// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[ARG]] : i16 to i32
+// CHECK-NEXT: [[SZ:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[SB:%.+]] = arith.cmpi slt, [[EXT]], [[SZ]] : i32
+// CHECK-NEXT: [[SV:%.+]] = arith.extsi [[SB]] : i1 to i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SV]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK: return [[INS1]] : vector<2xi32>
+func.func @extsi_scalar(%a : i16) -> i64 {
+ %r = arith.extsi %a : i16 to i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func @extsi_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16>
+// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32>
+// CHECK-NEXT: [[CSTE:%.+]] = arith.constant dense<0> : vector<3x1xi32>
+// CHECK-NEXT: [[CMP:%.+]] = arith.cmpi slt, [[EXT]], [[CSTE]] : vector<3x1xi32>
+// CHECK-NEXT: [[HIGH:%.+]] = arith.extsi [[CMP]] : vector<3x1xi1> to vector<3x1xi32>
+// CHECK-NEXT: [[CSTZ:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert_strided_slice [[EXT]], [[CSTZ]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[HIGH]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<3x2xi32>
+func.func @extsi_vector(%a : vector<3xi16>) -> vector<3xi64> {
+ %r = arith.extsi %a : vector<3xi16> to vector<3xi64>
+ return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func @extui_scalar1
+// CHECK-SAME: ([[ARG:%.+]]: i16) -> vector<2xi32>
+// CHECK-NEXT: [[EXT:%.+]] = arith.extui [[ARG]] : i16 to i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK: return [[INS0]] : vector<2xi32>
+func.func @extui_scalar1(%a : i16) -> i64 {
+ %r = arith.extui %a : i16 to i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func @extui_scalar2
+// CHECK-SAME: ([[ARG:%.+]]: i32) -> vector<2xi32>
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[ARG]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK: return [[INS0]] : vector<2xi32>
+func.func @extui_scalar2(%a : i32) -> i64 {
+ %r = arith.extui %a : i32 to i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func @extui_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16>
+// CHECK-NEXT: [[EXT:%.+]] = arith.extui [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32>
+// CHECK-NEXT: [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert_strided_slice [[EXT]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK: return [[INS0]] : vector<3x2xi32>
+func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
+ %r = arith.extui %a : vector<3xi16> to vector<3xi64>
+ return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func @trunci_scalar1
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i32
+// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: return [[EXT]] : i32
+func.func @trunci_scalar1(%a : i64) -> i32 {
+ %b = arith.trunci %a : i64 to i32
+ return %b : i32
+}
+
+// CHECK-LABEL: func @trunci_scalar2
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i16
+// CHECK-NEXT: [[EXTR:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: [[TRNC:%.+]] = arith.trunci [[EXTR]] : i32 to i16
+// CHECK-NEXT: return [[TRNC]] : i16
+func.func @trunci_scalar2(%a : i64) -> i16 {
+ %b = arith.trunci %a : i64 to i16
+ return %b : i16
+}
+
+// CHECK-LABEL: func @trunci_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xi16>
+// CHECK-NEXT: [[EXTR:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXTR]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT: [[TRNC:%.+]] = arith.trunci [[SHAPE]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT: return [[TRNC]] : vector<3xi16>
+func.func @trunci_vector(%a : vector<3xi64>) -> vector<3xi16> {
+ %b = arith.trunci %a : vector<3xi64> to vector<3xi16>
+ return %b : vector<3xi16>
+}
More information about the Mlir-commits
mailing list