[Mlir-commits] [mlir] c577f91 - [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (#88204)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 11:13:54 PDT 2024


Author: Charitha Saumya
Date: 2024-04-18T21:13:49+03:00
New Revision: c577f91d266b74d1b5df475fa2dce7c47fc6c57e

URL: https://github.com/llvm/llvm-project/commit/c577f91d266b74d1b5df475fa2dce7c47fc6c57e
DIFF: https://github.com/llvm/llvm-project/commit/c577f91d266b74d1b5df475fa2dce7c47fc6c57e.diff

LOG: [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (#88204)

This PR adds support for converting `vector.extract_strided_slice` and
`vector.extract` operations to equivalent `vector.shuffle` operations
that operate on linearized (1-D) vectors. `vector.shuffle` operations
operating on n-D (n > 1) vectors are also converted to equivalent shuffle
operations working on linearized vectors.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1a..fa2912a3e577d1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -389,6 +389,13 @@ void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
+/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
+/// vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
+                                                   RewritePatternSet &patterns,
+                                                   ConversionTarget &target,
+                                                   unsigned targetBitWidth);
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b59e9062e5a08e..69999f0918c103 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -13,9 +13,16 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <cstdint>
+#include <numeric>
 
 using namespace mlir;
 
@@ -103,6 +110,251 @@ struct LinearizeVectorizable final
     return success();
   }
 
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
+/// on a linearized vector.
+/// Following,
+///   vector.extract_strided_slice %source
+///         { offsets = [..], strides = [..], sizes = [..] }
+/// is converted to :
+///   %source_1d = vector.shape_cast %source
+///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the
+/// extraction.
+struct LinearizeVectorExtractStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtractStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
+    assert(!(extractOp.getVector().getType().isScalable() ||
+             dstType.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    ArrayAttr offsets = extractOp.getOffsets();
+    ArrayAttr sizes = extractOp.getSizes();
+    ArrayAttr strides = extractOp.getStrides();
+    if (!isConstantIntValue(strides[0], 1))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
+    Value srcVector = adaptor.getVector();
+    // If kD offsets are specified for an nD source vector (n > k), the
+    // granularity of the extraction is greater than 1. In this case the last
+    // (n-k) dimensions form the extraction granularity.
+    // Example :
+    //  vector.extract_strided_slice %src {
+    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+    //      vector<4x8x8xf32> to vector<2x2x8xf32>
+    // Here, extraction granularity is 8.
+    int64_t extractGranularitySize = 1;
+    int64_t nD = extractOp.getSourceVectorType().getRank();
+    int64_t kD = (int64_t)offsets.size();
+    int64_t k = kD;
+    while (k < nD) {
+      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
+      ++k;
+    }
+    // Get total number of extracted slices.
+    int64_t nExtractedSlices = 1;
+    for (Attribute size : sizes) {
+      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
+    }
+    // Compute the strides of the source vector considering first k dimensions.
+    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
+    for (int i = kD - 2; i >= 0; --i) {
+      sourceStrides[i] = sourceStrides[i + 1] *
+                         extractOp.getSourceVectorType().getShape()[i + 1];
+    }
+    // The final shuffle index list has nExtractedSlices *
+    // extractGranularitySize elements.
+    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
+                                          extractGranularitySize);
+    // Compute the strides of the extracted kD vector.
+    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
+    // Compute extractedStrides.
+    for (int i = kD - 2; i >= 0; --i) {
+      extractedStrides[i] =
+          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
+    }
+    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
+    // and compute the multi-dimensional index and the corresponding linearized
+    // index within the source vector.
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      int64_t index = i;
+      // Compute the corresponding multi-dimensional index.
+      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
+      for (int64_t j = 0; j < kD; ++j) {
+        multiDimIndex[j] = (index / extractedStrides[j]);
+        index -= multiDimIndex[j] * extractedStrides[j];
+      }
+      // Compute the corresponding linearized index in the source vector
+      // i.e. shift the multiDimIndex by the offsets.
+      int64_t linearizedIndex = 0;
+      for (int64_t j = 0; j < kD; ++j) {
+        linearizedIndex +=
+            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
+            sourceStrides[j];
+      }
+      // Fill the indices array from linearizedIndex to linearizedIndex +
+      // extractGranularitySize.
+      for (int64_t j = 0; j < extractGranularitySize; ++j) {
+        indices[i * extractGranularitySize + j] = linearizedIndex + j;
+      }
+    }
+    // Perform a shuffle to extract the kD vector.
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the ShuffleOp that works on nD (n > 1)
+/// vectors to a ShuffleOp that works on linearized vectors.
+/// Following,
+///   vector.shuffle %v1, %v2 [ shuffle_indices ]
+/// is converted to :
+///   %v1_1d = vector.shape_cast %v1
+///   %v2_1d = vector.shape_cast %v2
+///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
+/// of the original shuffle operation.
+struct LinearizeVectorShuffle final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorShuffle(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    assert(!(shuffleOp.getV1VectorType().isScalable() ||
+             shuffleOp.getV2VectorType().isScalable() ||
+             dstType.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    Value vec1 = adaptor.getV1();
+    Value vec2 = adaptor.getV2();
+    int shuffleSliceLen = 1;
+    int rank = shuffleOp.getV1().getType().getRank();
+
+    // If rank > 1, we need to do the shuffle in the granularity of slices
+    // instead of scalars. The size of a slice is the product of the rank-1
+    // innermost dims. The mask of the shuffle op specifies which slice to take
+    // from the outermost dim.
+    if (rank > 1) {
+      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
+      for (unsigned i = 1; i < shape.size(); ++i) {
+        shuffleSliceLen *= shape[i];
+      }
+    }
+
+    // For each value in the mask, we generate the indices of the source vectors
+    // that need to be shuffled to the destination vector. If shuffleSliceLen >
+    // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
+    // elements) instead of scalars.
+    ArrayAttr mask = shuffleOp.getMask();
+    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
+    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
+    for (auto [i, value] :
+         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
+
+      int64_t v = value.getZExtValue();
+      std::iota(indices.begin() + shuffleSliceLen * i,
+                indices.begin() + shuffleSliceLen * (i + 1),
+                shuffleSliceLen * v);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the ExtractOp to a ShuffleOp that works on a
+/// linearized vector.
+/// Following,
+///   vector.extract %source [ position ]
+/// is converted to :
+///   %source_1d = vector.shape_cast %source
+///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original extract.
+struct LinearizeVectorExtract final
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtract(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
+    assert(!(extractOp.getVector().getType().isScalable() ||
+             dstTy.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    // Dynamic position is not supported.
+    if (extractOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "dynamic position is not supported.");
+
+    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
+    int64_t size = extractOp.getVector().getType().getNumElements();
+
+    // Compute linearized offset.
+    int64_t linearizedOffset = 0;
+    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
+    for (auto [i, off] : llvm::enumerate(offsets)) {
+      size /= shape[i];
+      linearizedOffset += offsets[i] * size;
+    }
+
+    llvm::SmallVector<int64_t, 2> indices(size);
+    std::iota(indices.begin(), indices.end(), linearizedOffset);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
+        rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
 private:
   unsigned targetVectorBitWidth;
 };
@@ -145,3 +397,21 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   patterns.add<LinearizeConstant, LinearizeVectorizable>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
+
+void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, unsigned int targetBitWidth) {
+  target.addDynamicallyLegalOp<vector::ShuffleOp>(
+      [=](vector::ShuffleOp shuffleOp) -> bool {
+        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+                   ? (typeConverter.isLegal(shuffleOp) &&
+                      shuffleOp.getResult()
+                              .getType()
+                              .cast<mlir::VectorType>()
+                              .getRank() == 1)
+                   : true;
+      });
+  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
+               LinearizeVectorExtractStridedSlice>(
+      typeConverter, patterns.getContext(), targetBitWidth);
+}

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 22be78cd682057..b29ceab5783d7a 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -153,3 +153,95 @@ func.func @test_0d_vector() -> vector<f32> {
   // ALL: return %[[CST]]
   return %0 : vector<f32>
 }
+
+// -----
+// ALL-LABEL: test_extract_strided_slice_1
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+  // DEFAULT: return %[[RES]] : vector<2x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+  // BW-128: return %[[RES]] : vector<2x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+  // BW-0: return %[[RES]] : vector<2x2xf32>
+  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+     : vector<4x8xf32> to vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+// ALL-LABEL: test_extract_strided_slice_2
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+  // DEFAULT: return %[[RES]] : vector<1x4x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+  // BW-128: return %[[RES]] : vector<1x4x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
+  // BW-0: return %[[RES]] : vector<1x4x2xf32>
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+    : vector<2x8x2xf32> to vector<1x4x2xf32>
+  return %0 : vector<1x4x2xf32>
+}
+
+// -----
+// ALL-LABEL: test_vector_shuffle
+// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
+  // DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+  // DEFAULT-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+  // BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+  // BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // BW-128: return %[[RES]] : vector<8x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+  // BW-0: return %[[RES]] : vector<8x2xf32>
+  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+  return %0 : vector<8x2xf32>
+}
+
+// -----
+// ALL-LABEL: test_vector_extract
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // BW-128: return %[[RES]] : vector<8x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
+  // BW-0: return %[[RES]] : vector<8x2xf32>
+  %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+  return %0 : vector<8x2xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 00622599910567..c978699e179fca 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -867,6 +867,8 @@ struct TestVectorLinearize final
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
         typeConverter, patterns, target, targetVectorBitwidth);
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
+        typeConverter, patterns, target, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();


        


More information about the Mlir-commits mailing list