[Mlir-commits] [mlir] [mlir] Add support for vector.store sub-byte emulation. (PR #70293)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 1 16:06:03 PDT 2023


https://github.com/saienduri updated https://github.com/llvm/llvm-project/pull/70293

From f78c1e63c06cb7cca95accd9102c947fef3b9e77 Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Wed, 25 Oct 2023 22:43:37 -0700
Subject: [PATCH 1/3] add vector store support

fix mlir FileCheck test

minor change

minor change
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 68 +++++++++++++++++-
 .../Vector/vector-emulate-narrow-type.mlir    | 69 +++++++++++++++++++
 2 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3d65123373109b3..bf858d7d1ab97fb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -34,6 +34,70 @@ using namespace mlir;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to store when emulating narrow types.
+    // Here only the 1-D vector store is considered, and the N-D memref types
+    // should be linearized.
+    // For example, to emulate i4 using i8, the following op:
+    //
+    // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
+    //
+    // can be replaced with
+    //
+    // vector.store %bitcast_arg1, %alloc[%linear_index] : memref<16xi8>,
+    // vector<4xi8>
+
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = origElements / scale;
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements, newElementType),
+        op.getValueToStore());
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        op, bitCast.getResult(), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
@@ -755,8 +819,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
-               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index e1d6c3be494713e..73794d7ac146772 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -350,3 +350,72 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK32-SAME:     memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
 //      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
 //      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
+
+// -----
+
+func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xi8>
+    vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<8xi8>
+    return
+}
+
+// Expect no conversion; i8 is supported.
+//      CHECK: func @vector_store_i8
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
+//      CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4x8xi8>
+//      CHECK: vector.store %[[ARG0]], %[[ALLOC]][%[[ARG1]], %[[ARG2]]] : memref<4x8xi8>, vector<8xi8>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+//      CHECK32: func @vector_store_i8
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
+//      CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<8xi32>
+//      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi8> to vector<2xi32>
+//      CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<8xi32>, vector<2xi32>
+
+// -----
+
+func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xi4>
+    vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi4>, vector<8xi4>
+    return
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_store_i4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
+//      CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+//      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+//      CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<16xi8>, vector<4xi8>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_store_i4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
+//      CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+//      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+//      CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+    %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
+    vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
+    return
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//      CHECK: func @vector_store_i4_dynamic
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: index)
+//      CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+//      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+//      CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//      CHECK32: func @vector_store_i4_dynamic
+//      CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+//      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+//      CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+//      CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
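
For readers following the index arithmetic in the checks above: in the i4
case, srcBits = 4 and dstBits = 8, so scale = 2. A store at [%r, %c] into
memref<4x8xi4> starts at linear i4 element r * 8 + c, which lands at i8 byte
(r * 8 + c) / 2 = r * 4 + c floordiv 2, exactly the
affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> checked above (and
s0 + s1 floordiv 8 when packing into i32). A minimal standalone sketch of the
rewrite, using illustrative SSA names rather than the exact IR the pass
prints:

  // Before emulation: an 8-element i4 store into a 4x8 i4 memref.
  func.func @example(%v: vector<8xi4>, %r: index, %c: index) {
    %m = memref.alloc() : memref<4x8xi4>
    vector.store %v, %m[%r, %c] : memref<4x8xi4>, vector<8xi4>
    return
  }

  // After emulation with an i8 backing memref, conceptually:
  func.func @example(%v: vector<8xi4>, %r: index, %c: index) {
    %m = memref.alloc() : memref<16xi8>
    %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>()[%r, %c]
    %cast = vector.bitcast %v : vector<8xi4> to vector<4xi8>
    vector.store %cast, %m[%idx] : memref<16xi8>, vector<4xi8>
    return
  }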

From 246d71f78d5883dac54ab69a499e8c7207781d18 Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Wed, 1 Nov 2023 16:02:10 -0700
Subject: [PATCH 2/3] fix lit test

---
 mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 73794d7ac146772..b54045ef605ab8a 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -414,6 +414,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 //  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func @vector_store_i4_dynamic
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: index)
 //      CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
 //      CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
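
The added CHECK32-SAME line matters because FileCheck string substitutions
such as %[[ARG3]] only match once the variable has been captured earlier in
the same prefix's run; binding the block arguments on the function signature
makes the later index checks self-contained. A minimal illustration of the
capture-then-use pattern (a hypothetical test, not from this patch):

  // CHECK: func @f
  // CHECK-SAME: (%[[X:[a-zA-Z0-9]+]]: index)
  // CHECK: affine.apply #{{.*}}()[%[[X]]]

The first two directives bind X to the printed SSA name of the argument; the
third then requires the affine.apply operand to be that exact value.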

From ed9b236953af95f27446d74cb2ef962bdc57cb45 Mon Sep 17 00:00:00 2001
From: saienduri <enduri.sai at gmail.com>
Date: Wed, 1 Nov 2023 16:05:39 -0700
Subject: [PATCH 3/3] fix formatting

---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf858d7d1ab97fb..5c202cd905434e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -819,8 +819,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore, ConvertVectorTransferRead>(
-      typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
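
For downstream users, here is a rough C++ sketch of wiring these emulation
patterns (now including ConvertVectorStore) into a conversion pass. The
converter construction, the memref conversions call, and the legality setup
are assumptions modeled on the existing arith/memref narrow-type emulation
flow, not code from this patch:

  // Sketch of a pass body applying narrow-type emulation. Assumes the usual
  // headers for Arith/MemRef/Vector transforms and DialectConversion.
  void runOnOperation() {
    MLIRContext *ctx = &getContext();

    // Emulate sub-byte element types on top of an 8-bit backing type.
    arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

    RewritePatternSet patterns(ctx);
    // Picks up ConvertVectorLoad, ConvertVectorMaskedLoad, the new
    // ConvertVectorStore, and ConvertVectorTransferRead.
    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);

    ConversionTarget target(*ctx);
    // vector.store is legal only once its operand types need no conversion.
    target.addDynamicallyLegalOp<vector::StoreOp>(
        [&](vector::StoreOp op) { return typeConverter.isLegal(op); });

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }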


