[Mlir-commits] [mlir] f4bef78 - Add narrow type emulation pattern for vector.transfer_read
Hanhan Wang
llvmlistbot at llvm.org
Tue Aug 29 13:15:47 PDT 2023
Author: yzhang93
Date: 2023-08-29T13:15:19-07:00
New Revision: f4bef787bcb2536c674ff0967f22ac2bd74bd571
URL: https://github.com/llvm/llvm-project/commit/f4bef787bcb2536c674ff0967f22ac2bd74bd571
DIFF: https://github.com/llvm/llvm-project/commit/f4bef787bcb2536c674ff0967f22ac2bd74bd571.diff
LOG: Add narrow type emulation pattern for vector.transfer_read
Reviewed By: mravishankar, hanchung
Differential Revision: https://reviews.llvm.org/D158757
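
In short, the new ConvertVectorTransferRead pattern emulates a vector.transfer_read
of a sub-byte element type (e.g. i4) as a transfer_read of the converted wider
element type (e.g. i8), followed by a vector.bitcast back to the original vector
type; the padding value is widened with arith.extui. A minimal sketch of the
rewrite (illustrative IR, not taken from this commit's tests; assumes a 1-D
in-bounds read, with %src8 and %i8 standing for the converted memref and the
linearized index the pattern computes):

  %pad = arith.constant 0 : i4
  %v = vector.transfer_read %src[%i], %pad {in_bounds = [true]}
      : memref<16xi4>, vector<8xi4>

becomes, with scale = dstBits / srcBits = 8 / 4 = 2:

  %pad8 = arith.extui %pad : i4 to i8
  %r8 = vector.transfer_read %src8[%i8], %pad8
      : memref<8xi8>, vector<4xi8>
  %v = vector.bitcast %r8 : vector<4xi8> to vector<8xi4>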
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 01ee43354a711e..b2b7bfc5e4437c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,9 +36,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
- Type newElementType = sourceType.getElementType();
+ Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -81,16 +81,73 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto srcElementType = sourceType.getElementType();
- auto numElements =
- static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
+ auto numElements = (origElements + scale - 1) / scale;
auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+ loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
- numElements *= scale;
- auto castType = VectorType::get(numElements, oldElementType);
- auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+ auto bitCast =
+ rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+
+ rewriter.replaceOp(op, bitCast->getResult(0));
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertVectorTransferRead
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorTransferRead final
+ : OpConversionPattern<vector::TransferReadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto loc = op.getLoc();
+ auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
+ Type oldElementType = op.getType().getElementType();
+ Type newElementType = convertedType.getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = newElementType.getIntOrFloatBitWidth();
+
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+ int scale = dstBits / srcBits;
+
+ auto origElements = op.getVectorType().getNumElements();
+ if (origElements % scale != 0)
+ return failure();
+
+ auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+ adaptor.getPadding());
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(adaptor.getIndices()));
+
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newReadType = VectorType::get(numElements, newElementType);
+
+ auto newRead = rewriter.create<vector::TransferReadOp>(
+ loc, newReadType, adaptor.getSource(),
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+ newPadding);
+
+ auto bitCast =
+ rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
rewriter.replaceOp(op, bitCast->getResult(0));
return success();
@@ -107,5 +164,6 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
+ patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
+ typeConverter, patterns.getContext());
}
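
A note on the numElements change above: the old code computed
ceil(origElements / scale) through a round trip to double (std::ceil), while
the new code uses the equivalent integer ceiling division

  numElements = (origElements + scale - 1) / scale

e.g. origElements = 8, scale = 2 gives (8 + 2 - 1) / 2 = 4, matching the
vector<4xi8> read in the test below. The new transfer_read pattern also bails
out when origElements % scale != 0, so the division is exact in that case and
the rounding only matters if that restriction is later relaxed.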
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index e3c6b098b70ba4..6fcea33ddc952f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -79,3 +79,32 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
+ %c0 = arith.constant 0 : i4
+ %0 = memref.alloc() : memref<3x8xi4>
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+ memref<3x8xi4>, vector<8xi4>
+ return %1 : vector<8xi4>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_i4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_i4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
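
As a sanity check of the CHECK'd affine maps: for memref<3x8xi4>, the
linearized i4 index of element (s0, s1) is s0 * 8 + s1. With i8 emulation
(scale = 2) this becomes

  (s0 * 8 + s1) floordiv 2 = s0 * 4 + s1 floordiv 2

which is exactly #MAP in the CHECK run, and with i32 emulation (scale = 8)

  (s0 * 8 + s1) floordiv 8 = s0 + s1 floordiv 8

matching the CHECK32 map. The allocation sizes work out the same way:
24 i4 elements pack into 12 i8s or 3 i32s.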