[Mlir-commits] [mlir] 92bcb8c - [mlir][arith] Add `index_cast` and `index_castui` support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Thu Nov 17 11:05:43 PST 2022
Author: Jakub Kuderski
Date: 2022-11-17T14:04:17-05:00
New Revision: 92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c
URL: https://github.com/llvm/llvm-project/commit/92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c
DIFF: https://github.com/llvm/llvm-project/commit/92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c.diff
LOG: [mlir][arith] Add `index_cast` and `index_castui` support to WIE
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D138225
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 50069ae2b855e..ad42001996ec4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -598,6 +598,86 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
}
};
+// Convert IndexCast ops
+//===----------------------------------------------------------------------===//
+
+/// Returns true iff `type` is the scalar `index` type or a vector type whose element type is `index` (`vector<...index>`).
+static bool isIndexOrIndexVector(Type type) {
+ if (type.isa<IndexType>()) // Scalar `index`.
+ return true;
+
+ if (auto vectorTy = type.dyn_cast<VectorType>()) // Otherwise only a vector of `index` qualifies.
+ if (vectorTy.getElementType().isa<IndexType>())
+ return true;
+
+ return false; // Non-index scalars and vectors with any other element type.
+}
+/// Emulates `CastOp` (instantiated below with `arith::IndexCastOp` and `arith::IndexCastUIOp`) when the result is `index` or a vector of `index`: only the low half of the emulated wide input is cast, e.g. `i64 to index` becomes a `vector.extract` of element 0 followed by an `i32 to index` cast.
+template <typename CastOp>
+struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
+ using OpConversionPattern<CastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type resultType = op.getType();
+ if (!isIndexOrIndexVector(resultType))
+ return failure(); // This pattern only handles casts that produce `index` values.
+
+ Location loc = op.getLoc();
+ Type inType = op.getIn().getType();
+ auto newInTy = this->getTypeConverter()
+ ->convertType(inType)
+ .template dyn_cast_or_null<VectorType>();
+ if (!newInTy) // The converter did not map the input type to a vector type, so there is nothing to slice.
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unsupported type: {0}", inType));
+
+ // Keep only the last-dimension slice at offset 0 (the low half) of the converted input; the high half is discarded, truncating the original value.
+ Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
+ extracted = dropTrailingX1Dim(rewriter, loc, extracted);
+ rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted); // Same cast op kind, now fed by the narrow value.
+ return success();
+ }
+};
+/// Emulates `CastOp` when the input is `index` or a vector of `index`: casts to the narrow integer type first, then widens to the original result type with `ExtensionOp` (paired below as `IndexCastOp`/`ExtSIOp` and `IndexCastUIOp`/`ExtUIOp`).
+template <typename CastOp, typename ExtensionOp>
+struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
+ using OpConversionPattern<CastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type inType = op.getIn().getType();
+ if (!isIndexOrIndexVector(inType))
+ return failure(); // This pattern only handles casts that consume `index` values.
+
+ Location loc = op.getLoc();
+ auto *typeConverter =
+ this->template getTypeConverter<arith::WideIntEmulationConverter>();
+
+ Type resultType = op.getType();
+ auto newTy = typeConverter->convertType(resultType)
+ .template dyn_cast_or_null<VectorType>();
+ if (!newTy) // `newTy` is only used for this check: the result type must convert to a vector type.
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unsupported type: {0}", resultType));
+
+ // Build the narrow type the index cast will target: an integer of the converter's max target bit width, wrapped in a vector of the result's shape when the result is a vector.
+ Type narrowTy =
+ rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
+ if (auto vecTy = resultType.dyn_cast<VectorType>())
+ narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
+
+ // Cast the index to the narrow type, then sign- or zero-extend it back to the original
+ // wide result type. Let the matching conversion pattern legalize the extension op.
+ Value underlyingVal =
+ rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
+ rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertSelect
//===----------------------------------------------------------------------===//
@@ -841,8 +921,7 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
 // Rewrite this as a bitwise or of `arith.shrui` and sign extension bits.
// Perform as many ops over the narrow integer type as possible and let the
// other emulation patterns convert the rest.
- Value elemZero =
- createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
+ Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
Value signBit = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
signBit = dropTrailingX1Dim(rewriter, loc, signBit);
@@ -862,7 +941,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
// Use original arguments to create the right shift.
- Value shrui = rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
+ Value shrui =
+ rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
// Handle shifting by zero. This is necessary when the `signBits` shift is
@@ -870,7 +950,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
rhsElem0, elemZero);
isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), shrsi);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
+ shrsi);
return success();
}
@@ -1045,6 +1126,11 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
ConvertBitwiseBinary<arith::XOrIOp>,
// Extension and truncation ops.
- ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
- patterns.getContext());
+ ConvertExtSI, ConvertExtUI, ConvertTruncI,
+ // Cast ops.
+ ConvertIndexCastIntToIndex<arith::IndexCastOp>,
+ ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
+ ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
+ ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index bfef4162ef8c5..0f85e7a859386 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -365,6 +365,102 @@ func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
return %r : vector<3xi64>
}
+// CHECK-LABEL: func @index_cast_int_to_index_scalar
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[EXT]] : i32 to index
+// CHECK-NEXT: return [[RES]] : index
+func.func @index_cast_int_to_index_scalar(%a : i64) -> index {
+ %r = arith.index_cast %a : i64 to index
+ return %r : index
+}
+
+// CHECK-LABEL: func @index_cast_int_to_index_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT: return [[RES]] : vector<3xindex>
+func.func @index_cast_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+ %r = arith.index_cast %a : vector<3xi64> to vector<3xindex>
+ return %r : vector<3xindex>
+}
+
+// CHECK-LABEL: func @index_castui_int_to_index_scalar
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[EXT]] : i32 to index
+// CHECK-NEXT: return [[RES]] : index
+func.func @index_castui_int_to_index_scalar(%a : i64) -> index {
+ %r = arith.index_castui %a : i64 to index
+ return %r : index
+}
+
+// CHECK-LABEL: func @index_castui_int_to_index_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT: return [[RES]] : vector<3xindex>
+func.func @index_castui_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+ %r = arith.index_castui %a : vector<3xi64> to vector<3xindex>
+ return %r : vector<3xindex>
+}
+
+// CHECK-LABEL: func @index_cast_index_to_int_scalar
+// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT: [[CAST:%.+]] = arith.index_cast [[ARG]] : index to i32
+// CHECK-NEXT: [[C0I32:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[NEG:%.+]] = arith.cmpi slt, [[CAST]], [[C0I32]] : i32
+// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[EXT]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @index_cast_index_to_int_scalar(%a : index) -> i64 {
+ %r = arith.index_cast %a : index to i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func @index_cast_index_to_int_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT: arith.index_cast [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: arith.constant dense<0> : vector<3x1xi32>
+// CHECK-NEXT: arith.cmpi slt
+// CHECK-NEXT: arith.extsi
+// CHECK-NEXT: arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT: vector.insert_strided_slice
+// CHECK-NEXT: vector.insert_strided_slice
+// CHECK-NEXT: return {{%.+}} : vector<3x2xi32>
+func.func @index_cast_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+ %r = arith.index_cast %a : vector<3xindex> to vector<3xi64>
+ return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func @index_castui_index_to_int_scalar
+// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : index to i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[RES]] : vector<2xi32>
+func.func @index_castui_index_to_int_scalar(%a : index) -> i64 {
+ %r = arith.index_castui %a : index to i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func @index_castui_index_to_int_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[CAST]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.insert_strided_slice [[SHAPE]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
+func.func @index_castui_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+ %r = arith.index_castui %a : vector<3xindex> to vector<3xi64>
+ return %r : vector<3xi64>
+}
+
// CHECK-LABEL: func @trunci_scalar1
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i32
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
More information about the Mlir-commits
mailing list