[Mlir-commits] [mlir] add vector subbyte store support (PR #70293)
llvmlistbot at llvm.org
Mon Oct 30 22:14:15 PDT 2023
https://github.com/saienduri updated https://github.com/llvm/llvm-project/pull/70293
>From fbcbc8b2e76f1062515e317ca7ea19668d39d14f Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Wed, 25 Oct 2023 22:43:37 -0700
Subject: [PATCH 1/3] add vector store support
---
.../Transforms/VectorEmulateNarrowType.cpp | 69 ++++++++++++++++++-
.../Vector/vector-emulate-narrow-type.mlir | 68 ++++++++++++++++++
2 files changed, 135 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 94300291dcd7d23..fa68debbf8e2e85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,70 @@ using namespace mlir;
namespace {
+//===----------------------------------------------------------------------===//
+// ConvertVectorStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to store when emulating narrow types.
+    // Only 1-D vector stores are considered here; the indices into the N-D
+    // memref are linearized below.
+    // For example, to emulate i4 with i8, the following op:
+    //
+    // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
+    //
+    // can be replaced with
+    //
+    // vector.store %bitcast_arg1, %alloc[%linear_index] : memref<16xi8>,
+    // vector<4xi8>
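+    //
+    // As a worked example of the index arithmetic (srcBits = 4, dstBits = 8,
+    // scale = 2): the access [%arg2, %arg3] linearizes to
+    // %arg2 * 8 + %arg3 in i4 elements, which rescales to
+    // (%arg2 * 8 + %arg3) floordiv 2 = %arg2 * 4 + %arg3 floordiv 2 in i8
+    // elements; this is exactly the affine map the new tests check for.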
+
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements, newElementType),
+        op.getValueToStore());
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        op, bitCast->getResult(0), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
@@ -588,8 +652,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertVectorLoad, ConvertVectorTransferRead, ConvertVectorStore>(
+          typeConverter, patterns.getContext());
}
void vector::populateVectorNarrowTypeRewritePatterns(
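For reference, these patterns are driven through the dialect-conversion framework. Below is a minimal sketch of a driver, modeled on MLIR's narrow-type emulation flow; the wrapper function, the legality setup, and the 8-bit target bitwidth are illustrative assumptions, not part of this patch:

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper: rewrites sub-byte (e.g. i4) vector/memref ops under
// `root` into byte-aligned (i8) equivalents.
LogicalResult emulateNarrowTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // Map sub-byte element types to a byte-aligned container type.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  // An op is legal once all of its operand/result types are byte-aligned.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  // The pattern set extended by this patch (now including ConvertVectorStore),
  // plus the memref-side emulation patterns.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(root, target, std::move(patterns));
}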
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6fcea33ddc952fe..46618e1d315ecd0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -108,3 +108,71 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xi8>
+  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<8xi8>
+  return
+}
+}
+
+// Expect no conversions; i8 is supported.
+// CHECK: func @vector_store_i8
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4x8xi8>
+// CHECK: vector.store %[[ARG0]], %[[ALLOC]][%[[ARG1]], %[[ARG2]]] : memref<4x8xi8>, vector<8xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32: func @vector_store_i8
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<8xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<8xi32>, vector<2xi32>
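+//
+// With a 32-bit container (CHECK32), four i8 values pack into each i32, so
+// the 4x8 buffer shrinks to memref<8xi32> and the map divides the linear i8
+// offset by 4.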
+
+// -----
+
+func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi4>, vector<8xi4>
+  return
+}
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+  %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
+  vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
+  return
+}
+}
+
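+// With dynamic sizes, both the allocation size and the store index are
+// linearized in i4 elements and then rescaled to the container type: two i4
+// values per i8 (the floordiv 2 below) and eight per i32 (the floordiv 8 in
+// the CHECK32 maps).
+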
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK: func @vector_store_i4_dynamic
+// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32: func @vector_store_i4_dynamic
+// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
>From 1f4dd5cdf0d755b2ecf0fe334254387353a01466 Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Thu, 26 Oct 2023 16:13:18 -0700
Subject: [PATCH 2/3] fix mlir FileCheck test
---
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 46618e1d315ecd0..93d3488b5ed8a2e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -162,6 +162,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
// CHECK: func @vector_store_i4_dynamic
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: index)
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
>From 0e72f90d998213c3ea7fa3741d0bee4f0a029c6c Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Mon, 30 Oct 2023 22:13:56 -0700
Subject: [PATCH 3/3] minor change
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fa68debbf8e2e85..03c0caf3dcc4088 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -85,7 +85,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());
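+    // origElements is a multiple of scale here (the pattern bails out
+    // earlier otherwise), so plain division is exact and the ceil-div was
+    // unneeded.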