[Mlir-commits] [mlir] 2de936b - [mlir][vector] Fix emulation of "narrow" type `vector.store` (#133231)

llvmlistbot at llvm.org
Thu Apr 24 10:05:44 PDT 2025


Author: Andrzej WarzyƄski
Date: 2025-04-24T18:05:41+01:00
New Revision: 2de936b6eb38e7a37224a97c2a22aa79b9dfb9dc

URL: https://github.com/llvm/llvm-project/commit/2de936b6eb38e7a37224a97c2a22aa79b9dfb9dc
DIFF: https://github.com/llvm/llvm-project/commit/2de936b6eb38e7a37224a97c2a22aa79b9dfb9dc.diff

LOG: [mlir][vector] Fix emulation of "narrow" type `vector.store` (#133231)

Below are two examples of "narrow" `vector.store` ops. The first example
  does not require partial stores (and hence no RMW stores); it is
  currently emulated correctly.
  ```mlir
  func.func @example_1(%arg0: vector<4xi2>) {
      %0 = memref.alloc() : memref<13xi2>
      %c4 = arith.constant 4 : index
      vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
      return
  }
  ```

  The second example requires a partial (and hence RMW) store because the
  offset (`%c3`) is not aligned to the container type boundary.
  ```mlir
  func.func @example_2(%arg0: vector<4xi2>) {
      %0 = memref.alloc() : memref<13xi2>
      %c3 = arith.constant 3 : index
      vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
      return
  }
  ```

  At the moment, this is incorrectly emulated as a single "full" store (with
  a wrong offset) instead of the partial stores that are actually required:
  ```mlir
  func.func @example_2(%arg0: vector<4xi2>) {
    %alloc = memref.alloc() : memref<4xi8>
    %0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
    %c0 = arith.constant 0 : index
    vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
    return
  }
  ```

  The incorrect emulation stems from this simplified (i.e. incomplete)
  calculation of the front padding:
  ```cpp
      std::optional<int64_t> foldedNumFrontPadElems =
          isDivisibleInSize ? 0
                            : getConstantIntValue(linearizedInfo.intraDataOffset);
  ```

  Since `isDivisibleInSize` is `true` (i8 / i2 = 4):
    * front padding is set to `0` and, as a result,
    * the input offset (`%c3`) is ignored, and
    * we incorrectly assume that partial stores won't be needed.
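
  For illustration only (not part of this patch), below is a standalone
  sketch of that arithmetic with the constants from `@example_2` hard-coded.
  All names are hypothetical, and `intraDataOffset` is approximated here as
  the store index modulo the number of i2 elements per byte:
  ```cpp
  #include <cstdint>
  #include <cstdio>

  int main() {
    // vector<4xi2> stored into an i8-backed buffer: 4 i2 elements per i8.
    const int64_t emulatedPerContainerElem = 8 / 2;
    const int64_t origElements = 4; // number of i2 elements being stored
    const int64_t inputOffset = 3;  // %c3 from @example_2

    // i8 / i2 divides evenly, so the old condition kicks in ...
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    // ... and the front padding is forced to 0, discarding the offset.
    int64_t oldFrontPad =
        isDivisibleInSize ? 0 : inputOffset % emulatedPerContainerElem;

    // What a correct partial store needs: the intra-byte element offset.
    int64_t requiredFrontPad = inputOffset % emulatedPerContainerElem;

    // Prints "old = 0, required = 3".
    std::printf("old = %lld, required = %lld\n", (long long)oldFrontPad,
                (long long)requiredFrontPad);
    return 0;
  }
  ```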

  Note that in both examples we are storing `vector<4xi2>` into
  `memref<13xi2>` (i.e. the trailing dims _differ_), and hence partial
  stores might in fact be required. The condition above is updated to:
  ```cpp
      std::optional<int64_t> foldedNumFrontPadElems =
          (isDivisibleInSize && trailingDimsMatch)
              ? 0
              : getConstantIntValue(linearizedInfo.intraDataOffset);
  ```

  This change ensures that the input offset is properly taken into
  account, which fixes the issue. It doesn't affect `@example_1`.
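
  As a quick sanity check (again just a sketch, not code from this patch),
  here is how the updated guard evaluates for the two examples above, with
  the `memref<13xi2>` trailing dim and the `vector<4xi2>` element count
  plugged in by hand:
  ```cpp
  #include <cstdint>
  #include <cstdio>

  int main() {
    const int64_t trailingDim = 13;      // memref<13xi2>
    const int64_t origElements = 4;      // vector<4xi2>
    const bool isDivisibleInSize = true; // 4 x i2 fills an i8 exactly

    // New guard: front padding may only be skipped when the trailing dims
    // also match (13 != 4 here, so they do not).
    bool trailingDimsMatch = trailingDim == origElements;
    bool skipFrontPad = isDivisibleInSize && trailingDimsMatch;

    // Prints "skipFrontPad = 0": the %c3 offset of @example_2 is no longer
    // ignored; @example_1 (offset %c4, byte-aligned) is unaffected because
    // its intra-byte offset is 0 anyway.
    std::printf("skipFrontPad = %d\n", (int)skipFrontPad);
    return 0;
  }
  ```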

  Additional comments are added to clarify the current logic.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
    mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 8d4dcb2b27bf9..a560aa1b1e680 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -593,10 +593,19 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
+    // Do the trailing dims of the source and destination match? If yes, then
+    // the corresponding index must be 0.
+    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
+    // However, that makes some tests fail, so we need to audit first.
+    auto trailingDim = op.getBase().getType().getShape().back();
+    bool trailingDimsMatch =
+        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
+    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +617,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isDivisibleInSize ? 0
-                          : getConstantIntValue(linearizedInfo.intraDataOffset);
+        (isDivisibleInSize && trailingDimsMatch)
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -619,15 +629,38 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto memrefBase = cast<MemRefValue>(adaptor.getBase());
 
-    // Conditions when atomic RMWs are not needed:
+    // RMWs are not needed when:
+    //  * no _partial_ stores are required.
+    // A partial store is defined as a store in which only a part of the
+    // container element is overwritten, e.g.
+    //
+    //    Dest before (8 bits)
+    //        +----------+
+    //        | 11000000 |
+    //        +----------+
+    //
+    //    Dest after storing 0xF at offset 4 (in bits)
+    //        +----------+
+    //        | 11001111 |
+    //        +----------+
+    //
+    // At a higher level, this translates to:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the emulated width boundary.
+    // 2. The address of the store is aligned to the container type width
+    //    boundary.
+    //
+    // EXAMPLE 1:
+    //  Requires partial store:
+    //    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
     //
-    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
-    // need unaligned emulation because the store address is aligned and the
-    // source is a whole byte.
-    bool emulationRequiresPartialStores =
-        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+    // EXAMPLE 2:
+    //  Does not require a partial store:
+    //    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    //
+    // TODO: Take linearizedInfo.linearizedOffset into account. This is
+    // currently not needed/used/exercised as all our tests set offset to 0.
+    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
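
For illustration only (not part of the commit), the "partial store" described
in the comment block above boils down to a byte-level read-modify-write. A
minimal standalone sketch with hypothetical names, matching the bit diagram
(11000000 -> 11001111 after storing 0xF at bit offset 4):

```cpp
#include <cstdint>
#include <cstdio>

// Merge `value` into `dest`, overwriting only the bits selected by `mask`.
// In the real lowering this merge happens inside memref.generic_atomic_rmw,
// using vector.bitcast + arith.select on i2 lanes rather than raw bit masks.
uint8_t partialStoreByte(uint8_t dest, uint8_t value, uint8_t mask) {
  return static_cast<uint8_t>((dest & ~mask) | (value & mask));
}

int main() {
  uint8_t dest = 0b11000000;  // "Dest before" from the comment above
  uint8_t value = 0b00001111; // 0xF placed at bit offset 4 (low nibble)
  uint8_t mask = 0b00001111;  // only the low 4 bits are overwritten
  // Prints "cf", i.e. 11001111 -- the "Dest after" byte.
  std::printf("%02x\n",
              static_cast<unsigned>(partialStoreByte(dest, value, mask)));
  return 0;
}
```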

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6fc974200c6f3..21f073efc49b2 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+    %0 = memref.alloc() : memref<13xi2>
+    %c4 = arith.constant 4 : index
+    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    return
+}
+// CHECK-LABEL:   func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT:       memref.generic_atomic_rmw
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK:           %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+    %0 = memref.alloc() : memref<13xi2>
+    %c3 = arith.constant 3 : index
+    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+    return
+}
+
+// CHECK-LABEL:   func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK:           %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK:           %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK:           %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK:           %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK:           memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK:             %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK:             %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK:             %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK:             %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK:             %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK:             memref.atomic_yield %[[RES_1]] : i8
+// CHECK:           }
+
+// Second atomic RMW:
+// CHECK:           %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:           %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK:           %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK:           %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK:           %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK:            memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK:           ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK:             %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK:             %[[DOWNCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK:             %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DOWNCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK:             %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK:             %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK:             memref.atomic_yield %[[RES_2]] : i8
+// CHECK:           }
+
+// -----
+
 func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
     %src = memref.alloc() : memref<3x3xi2>
     %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9dc3eb6989c6c..9e2d131f421b7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -439,6 +439,11 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
+// FIXME: This example assumes that the store happens at a byte boundary, but
+// that's not guaranteed. Below is a counter-example with specific dimensions:
+//    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
+// TODO: Revisit post #136797
+
 func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
     %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
     vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>


        

