[Mlir-commits] [mlir] [mlir][vector] Add linearization patterns for vector.load & vector.store (PR #137558)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 27 15:41:55 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>

This PR is a breakdown [1 / 4] of PR #136193.
It adds linearization patterns for the vector.load and vector.store ops. The patterns currently support only 2D vectors.
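For context, the load pattern's doc comment sketches the rewrite roughly as below (illustrative shapes and SSA names, not taken from the tests; the IR actually produced after the full type conversion also contains the vector.shape_cast/vector.shuffle ops visible in the test file, and the store pattern is the analogous extract-then-store sequence):

```mlir
// Before: a single 2D load.
%0 = vector.load %base[%i, %j] : memref<4x4xf32>, vector<4x4xf32>

// After (sketch): one 1D load per row, inserted into an accumulator that
// starts as a zero constant; the row index is advanced with arith.addi.
%cst = arith.constant dense<0.0> : vector<4x4xf32>
%r0  = vector.load %base[%i, %j] : memref<4x4xf32>, vector<4xf32>
%a0  = vector.insert %r0, %cst [0] : vector<4xf32> into vector<4x4xf32>
%c1  = arith.constant 1 : index
%i1  = arith.addi %i, %c1 : index
%r1  = vector.load %base[%i1, %j] : memref<4x4xf32>, vector<4xf32>
%a1  = vector.insert %r1, %a0 [1] : vector<4xf32> into vector<4x4xf32>
// ... and so on for the remaining rows of the 4x4 case.
```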

---

Patch is 20.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137558.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+149-11) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+110) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..e4d88de2cf4ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,7 +26,12 @@
 
 using namespace mlir;
 
+constexpr unsigned defaultTargetVectorBitWidth =
+    std::numeric_limits<unsigned>::max();
+
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  if (targetBitWidth == 0)
+    return false;
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -82,7 +87,7 @@ struct LinearizeConstantLike final
 
   LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +141,7 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +180,7 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +294,7 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +367,17 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +434,7 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +515,7 @@ struct LinearizeVectorBitCast final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +540,139 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors
+// clang-format on
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    VectorType vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+
+    auto unrollCount = shape[0];
+    auto vecSize = shape[1];
+    VectorType newVecType =
+        VectorType::get({vecSize}, vecType.getElementType());
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    // Construct the 2D vector.
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
+    for (auto i = 0; i < unrollCount; i++) {
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    VectorType vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+
+    auto unrollCount = shape[0];
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+                                                    adaptor.getValueToStore());
+
+    for (auto i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                       indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +691,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -567,9 +704,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorLoad, LinearizeVectorStore>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..9e793c5dc8233 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,113 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: linearize_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @linearize_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return %[[LOAD]] : vector<4x4xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// -----
+// ALL-LABEL: linearize_vector_store
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
+func.func @linearize_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: return
+
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: return
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03f907e46c2c6..14c7e9d554cd9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -852,7 +852,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/137558


More information about the Mlir-commits mailing list