[Mlir-commits] [mlir] [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (PR #88204)

Charitha Saumya llvmlistbot at llvm.org
Tue Apr 16 14:10:22 PDT 2024


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/88204

>From dc63b10f878bf2609bd04cc7668b238939969282 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 14:04:04 -0700
Subject: [PATCH 01/11] add linearize patterns for Extract,
 ExtractStridedSlice, Shuffle VectorOps

---
 .../Vector/Transforms/VectorLinearize.cpp     | 249 +++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       |  80 ++++++
 2 files changed, 328 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b59e9062e5a08e..257c940e5ed93c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -103,6 +105,234 @@ struct LinearizeVectorizable final
     return success();
   }
 
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Linearizes vector.extract_strided_slice into a vector.shuffle that selects
+/// the same elements from the flattened (1-D) source vector.
+struct LinearizeVectorExtractStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtractStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    auto loc = extractOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+    if (extractOp.getVector().getType().isScalable() ||
+        dstType.cast<VectorType>().isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    auto offsets = extractOp.getOffsets().getValue();
+    auto sizes = extractOp.getSizes().getValue();
+    auto strides = extractOp.getStrides().getValue();
+    // The index computation below assumes unit stride in every sliced
+    // dimension, so reject any non-unit stride, not just the outermost one.
+    if (!llvm::all_of(strides,
+                      [](Attribute s) { return isConstantIntValue(s, 1); }))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
+
+    Value srcVector = adaptor.getVector();
+    ArrayRef<int64_t> srcShape = extractOp.getSourceVectorType().getShape();
+
+    // If kD offsets are specified for an nD source vector (n > k), the
+    // granularity of the extraction is greater than 1: the trailing (n - k)
+    // dimensions form one contiguous slice. Example:
+    //   %0 = vector.extract_strided_slice %src
+    //          { offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+    //          : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // Here the extraction granularity is 8.
+    int64_t extractSliceLen = 1;
+    int64_t n = extractOp.getSourceVectorType().getRank();
+    int64_t k = static_cast<int64_t>(offsets.size());
+    for (int64_t i = k; i < n; ++i)
+      extractSliceLen *= srcShape[i];
+
+    // Get the total number of extracted slices.
+    int64_t nExtractedSlices = 1;
+    for (Attribute size : sizes)
+      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
+
+    // Compute the strides of the linearized source vector for the first k
+    // dimensions.
+    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+    for (int64_t i = k - 2; i >= 0; --i)
+      sourceStrides[i] = sourceStrides[i + 1] * srcShape[i + 1];
+
+    // The final shuffle indices have nExtractedSlices * extractSliceLen
+    // elements.
+    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
+    // Compute the strides of the extracted kD vector.
+    llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
+    for (int64_t i = k - 2; i >= 0; --i)
+      extractedStrides[i] =
+          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
+    // Iterate over all extracted slices from 0 to nExtractedSlices - 1 and
+    // compute the multi-dimensional index and the corresponding linearized
+    // index within the source vector.
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      int64_t index = i;
+      // Compute the corresponding multi-dimensional index.
+      llvm::SmallVector<int64_t, 4> multiDimIndex(k, 0);
+      for (int64_t j = 0; j < k; ++j) {
+        multiDimIndex[j] = index / extractedStrides[j];
+        index -= multiDimIndex[j] * extractedStrides[j];
+      }
+      // Compute the corresponding linearized index in the source vector,
+      // i.e. shift the multiDimIndex by the offsets.
+      int64_t linearizedIndex = 0;
+      for (int64_t j = 0; j < k; ++j) {
+        linearizedIndex +=
+            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
+            sourceStrides[j];
+      }
+      // Fill the indices array from linearizedIndex to
+      // linearizedIndex + extractSliceLen.
+      for (int64_t j = 0; j < extractSliceLen; ++j)
+        indices[i * extractSliceLen + j] = linearizedIndex + j;
+    }
+    // Perform a shuffle to extract the kD vector.
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Linearizes an nD (n > 1) vector.shuffle into a vector.shuffle on the
+/// flattened operands. The nD mask selects whole outermost-dimension slices,
+/// so each mask entry expands to shuffleSliceLen consecutive element indices.
+struct LinearizeVectorShffle final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorShffle(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    auto loc = shuffleOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+    if (shuffleOp.getV1VectorType().isScalable() ||
+        shuffleOp.getV2VectorType().isScalable() ||
+        dstType.cast<VectorType>().isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    Value vec1 = adaptor.getV1();
+    Value vec2 = adaptor.getV2();
+
+    // If rank > 1, the shuffle works at the granularity of slices instead of
+    // scalars: a slice is the product of the (rank - 1) innermost dims, and
+    // each mask entry picks one slice along the outermost dim. The length is
+    // kept as int64_t (like the dims) to avoid narrowing the product.
+    int64_t shuffleSliceLen = 1;
+    int64_t rank = shuffleOp.getV1VectorType().getRank();
+    if (rank > 1) {
+      for (int64_t dim : shuffleOp.getV1VectorType().getShape().drop_front())
+        shuffleSliceLen *= dim;
+    }
+
+    // For each value in the mask, generate the indices of the source
+    // elements that are shuffled into the destination vector: mask value v
+    // selects elements [v * shuffleSliceLen, (v + 1) * shuffleSliceLen) of
+    // the concatenation of the linearized vec1 and vec2.
+    auto mask = shuffleOp.getMask();
+    llvm::SmallVector<int64_t, 2> indices(mask.size() * shuffleSliceLen);
+    for (auto [i, value] :
+         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
+      int64_t v = value.getZExtValue();
+      std::iota(indices.begin() + shuffleSliceLen * i,
+                indices.begin() + shuffleSliceLen * (i + 1),
+                shuffleSliceLen * v);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Linearizes a vector-result vector.extract into a vector.shuffle that picks
+/// the contiguous run of source elements at the (static) position.
+struct LinearizeVectorExtract final
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtract(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // A scalar result (full-rank position) cannot come from a vector.shuffle.
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+    auto dstTy = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
+    if (extractOp.getVector().getType().isScalable() ||
+        dstTy.cast<VectorType>().isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+    if (extractOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "dynamic position is not supported.");
+
+    // Compute the linearized offset of the extracted slice.
+    ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
+    int64_t size = extractOp.getVector().getType().getNumElements();
+    int64_t linearizedOffset = 0;
+    for (auto [i, off] : llvm::enumerate(extractOp.getStaticPosition())) {
+      size /= shape[i];
+      linearizedOffset += off * size;
+    }
+
+    llvm::SmallVector<int64_t, 2> indices(size);
+    std::iota(indices.begin(), indices.end(), linearizedOffset);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
+        rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
 private:
   unsigned targetVectorBitWidth;
 };
@@ -139,9 +369,26 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
                       ? typeConverter.isLegal(op)
                       : true);
         }
+        if (isa<vector::ShuffleOp>(op)) {
+          return (isLessThanTargetBitWidth(op, targetBitWidth)
+                      ? (typeConverter.isLegal(op) &&
+                         op->getResult(0)
+                                 .getType()
+                                 .cast<mlir::VectorType>()
+                                 .getRank() == 1)
+                      : true);
+        }
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+  // target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>(
+  //     [=](mlir::Operation *op) {
+  //       return op->getResult(0).getType().cast<mlir::VectorType>().getRank()
+  //       ==
+  //              1;
+  //     });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
+               LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 212541c79565b6..d4215a88977eb7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -164,3 +164,87 @@ func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]x

   return %2 : vector<2x[2]xf32>
 }
+
+// -----
+// Full-rank offsets on a 2-D source: each extracted slice is one element.
+// ALL-LABEL: test_extract_strided_slice_1
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+  // DEFAULT: return %[[RES]] : vector<2x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+  // BW-128: return %[[RES]] : vector<2x2xf32>
+  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+     : vector<4x8xf32> to vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+// 2-D offsets on a 3-D source: the trailing dim (size 2) is the slice length.
+// ALL-LABEL: test_extract_strided_slice_2
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+  // DEFAULT: return %[[RES]] : vector<1x4x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+  // BW-128: return %[[RES]] : vector<1x4x2xf32>
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+    : vector<2x8x2xf32> to vector<1x4x2xf32>
+  return %0 : vector<1x4x2xf32>
+}
+
+// -----
+// 2-D shuffle: each mask entry selects a whole row (2 elements) of an operand.
+// ALL-LABEL: test_vector_shuffle
+// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
+  // DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+  // DEFAULT: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+  // BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+  // BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // BW-128: return %[[RES]] : vector<8x2xf32>
+  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+  return %0 : vector<8x2xf32>
+}
+
+// -----
+// Partial static position: the result is a contiguous run of 16 elements.
+// ALL-LABEL: test_vector_extract
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
+  // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // DEFAULT: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+  // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+  // BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+  // BW-128: return %[[RES]] : vector<8x2xf32>
+  %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+  return %0 : vector<8x2xf32>
+}

>From de748c0f93e1ead19c5cc402940c9df8ab180b2d Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 14:59:06 -0700
Subject: [PATCH 02/11] remove comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 257c940e5ed93c..e5157abd245b5d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -381,13 +381,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  // target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>(
-  //     [=](mlir::Operation *op) {
-  //       return op->getResult(0).getType().cast<mlir::VectorType>().getRank()
-  //       ==
-  //              1;
-  //     });
-
   patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
                LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);

>From 962243c475e9f4b2b4fc1231edf92bc06ec12767 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 15:21:22 -0700
Subject: [PATCH 03/11] fix test

---
 mlir/test/Dialect/Vector/linearize.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 88d011e7c8594c..67f0f667a6b205 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -169,6 +169,9 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   // BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
   // BW-128: return %[[RES]] : vector<2x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+  // BW-0: return %[[RES]] : vector<2x2xf32>
   %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
      : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
@@ -189,6 +192,9 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
   // BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
   // BW-128: return %[[RES]] : vector<1x4x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
+  // BW-0: return %[[RES]] : vector<1x4x2xf32>
   %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
     : vector<2x8x2xf32> to vector<1x4x2xf32>
   return %0 : vector<1x4x2xf32>
@@ -211,6 +217,9 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
   // BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // BW-128: return %[[RES]] : vector<8x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+  // BW-0: return %[[RES]] : vector<8x2xf32>
   %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
   return %0 : vector<8x2xf32>
 }
@@ -230,6 +239,9 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   // BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // BW-128: return %[[RES]] : vector<8x2xf32>
+
+  // BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
+  // BW-0: return %[[RES]] : vector<8x2xf32>
   %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
   return %0 : vector<8x2xf32>
 }

>From e20be009e4b7e9abdc6acd09de44d066abcd4a30 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 12:22:24 -0700
Subject: [PATCH 04/11] address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index e5157abd245b5d..c85f8ecf825090 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,6 +109,9 @@ struct LinearizeVectorizable final
   unsigned targetVectorBitWidth;
 };
 
+
+/// This pattern converts the vector.extract_strided_slice operation to a
+/// vector.shuffle operation that works on a linearized vector.
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -137,18 +140,16 @@ struct LinearizeVectorExtractStridedSlice final
     auto offsets = extractOp.getOffsets().getValue();
     auto sizes = extractOp.getSizes().getValue();
     auto strides = extractOp.getStrides().getValue();
-
     if (!isConstantIntValue(strides[0], 1))
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
-
     Value srcVector = adaptor.getVector();
-
     // if kD offsets are specified for nd source vector (n > k), the granularity
     // of the extraction is greater than 1. In this case last (n-k) dimensions
-    // form the extraction granularity. example : %0 =
-    // vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
-    // strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // form the extraction granularity. 
+    // example : 
+    //  %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
+    //     strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
     // here, extraction granularity is 8.
     int64_t extractSliceLen = 1;
     auto n = extractOp.getSourceVectorType().getRank();
@@ -158,13 +159,11 @@ struct LinearizeVectorExtractStridedSlice final
         extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
       }
     }
-
     // get total number of extracted slices
     int64_t nExtractedSlices = 1;
     for (auto size : sizes) {
       nExtractedSlices *= size.cast<IntegerAttr>().getInt();
     }
-
     // compute the strides of the source vector considering first k dimensions
     llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
     for (int i = k - 2; i >= 0; --i) {
@@ -209,7 +208,6 @@ struct LinearizeVectorExtractStridedSlice final
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
         extractOp, dstType, srcVector, srcVector,
         rewriter.getI64ArrayAttr(indices));
-
     return success();
   }
 
@@ -217,6 +215,9 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+
+/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
+/// vectors to a vector.shuffle operation that works on linearized vectors.
 struct LinearizeVectorShffle final
     : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -234,7 +235,6 @@ struct LinearizeVectorShffle final
     auto loc = shuffleOp.getLoc();
     if (!dstType)
       return rewriter.notifyMatchFailure(loc, "cannot convert type.");
-
     if (shuffleOp.getV1VectorType().isScalable() ||
         shuffleOp.getV2VectorType().isScalable() ||
         dstType.cast<VectorType>().isScalable())
@@ -246,7 +246,6 @@ struct LinearizeVectorShffle final
 
     auto vec1 = adaptor.getV1();
     auto vec2 = adaptor.getV2();
-
     int shuffleSliceLen = 1;
     int rank = shuffleOp.getV1().getType().getRank();
 
@@ -261,10 +260,13 @@ struct LinearizeVectorShffle final
       }
     }
 
+    // for each value in the mask, we generate the indices of the source vectors
+    // that needs to be shuffled to the destination vector. if shuffleSliceLen > 1
+    // we need to shuffle the slices (consecutive shuffleSliceLen number of elements) 
+    // instead of scalars.
     auto mask = shuffleOp.getMask();
-    auto totalSize = mask.size() * shuffleSliceLen;
-
-    llvm::SmallVector<int64_t, 2> indices(totalSize);
+    auto totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
+    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
     for (auto [i, value] :
          llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
 
@@ -276,7 +278,6 @@ struct LinearizeVectorShffle final
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
         shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
-
     return success();
   }
 
@@ -284,6 +285,9 @@ struct LinearizeVectorShffle final
   unsigned targetVectorBitWidth;
 };
 
+
+/// This pattern converts the vector.extract operation to a vector.shuffle operation
+/// that works on a linearized vector.
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;

>From 0d9406df3f522c6a57fae495d78c798baed3c419 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 16:02:20 -0700
Subject: [PATCH 05/11] address comments

---
 .../Vector/Transforms/VectorRewritePatterns.h |  6 ++
 .../Vector/Transforms/VectorLinearize.cpp     | 80 +++++++++++--------
 .../Dialect/Vector/TestVectorTransforms.cpp   |  2 +
 3 files changed, 53 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1a..d630a3562beb94 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -389,6 +389,12 @@ void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
+/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
+/// vector shuffle operations.
+void populateVectorLinearizeToShuffleRewritePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, unsigned targetBitWidth);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c85f8ecf825090..08f6f03eb56ac8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
 #include <numeric>
 
 using namespace mlir;
@@ -109,7 +110,6 @@ struct LinearizeVectorizable final
   unsigned targetVectorBitWidth;
 };
 
-
 /// This pattern converts the vector.extract_strided_slice operation to a
 /// vector.shuffle operation that works on a linearized vector.
 struct LinearizeVectorExtractStridedSlice final
@@ -144,67 +144,68 @@ struct LinearizeVectorExtractStridedSlice final
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
     Value srcVector = adaptor.getVector();
-    // if kD offsets are specified for nd source vector (n > k), the granularity
+    // If kD offsets are specified for nd source vector (n > k), the granularity
     // of the extraction is greater than 1. In this case last (n-k) dimensions
-    // form the extraction granularity. 
-    // example : 
-    //  %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
+    // form the extraction granularity.
+    // example :
+    //  %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
+    //  2],
     //     strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
     // here, extraction granularity is 8.
     int64_t extractSliceLen = 1;
     auto n = extractOp.getSourceVectorType().getRank();
-    auto k = (int64_t)offsets.size();
+    int64_t k = (int64_t)offsets.size();
     if (n > k) {
       for (unsigned i = 0; i < n - k; i++) {
         extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
       }
     }
-    // get total number of extracted slices
+    // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
     for (auto size : sizes) {
       nExtractedSlices *= size.cast<IntegerAttr>().getInt();
     }
-    // compute the strides of the source vector considering first k dimensions
+    // Compute the strides of the source vector considering first k dimensions.
     llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
     for (int i = k - 2; i >= 0; --i) {
       sourceStrides[i] = sourceStrides[i + 1] *
                          extractOp.getSourceVectorType().getShape()[i + 1];
     }
-    // final shuffle indices has nExtractedElems * extractSliceLen elements
+    // Final shuffle indices has nExtractedElems * extractSliceLen elements.
     llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
-    // compute the strides of the extracted kD vector
+    // Compute the strides of the extracted kD vector.
     llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
-    // compute extractedStrides
+    // Compute extractedStrides.
     for (int i = k - 2; i >= 0; --i) {
       extractedStrides[i] =
           extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
     }
-    // iterate over all extracted slices from 0 to nExtractedElems-1
+    // Iterate over all extracted slices from 0 to nExtractedElems-1
     // and compute the multi-dimensional index and the corresponding linearized
-    // index within the source vector
+    // index within the source vector.
     for (int64_t i = 0; i < nExtractedSlices; ++i) {
       int64_t index = i;
-      // compute the corresponding multi-dimensional index
+      // Compute the corresponding multi-dimensional index.
       llvm::SmallVector<int64_t, 4> multiDimIndex(k, 0);
       for (int64_t j = 0; j < k; ++j) {
         multiDimIndex[j] = (index / extractedStrides[j]);
         index -= multiDimIndex[j] * extractedStrides[j];
       }
-      // compute the corresponding linearized index in the source vector
-      // i.e. shift the multiDimIndex by the offsets
+      // Compute the corresponding linearized index in the source vector
+      // i.e. shift the multiDimIndex by the offsets.
       int64_t linearizedIndex = 0;
       for (int64_t j = 0; j < k; ++j) {
         linearizedIndex +=
             (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
             sourceStrides[j];
       }
-      // fill the indices array form linearizedIndex to linearizedIndex +
-      // sliceLen
+      // Fill the indices array form linearizedIndex to linearizedIndex +
+      // sliceLen.
       for (int64_t j = 0; j < extractSliceLen; ++j) {
         indices[i * extractSliceLen + j] = linearizedIndex + j;
       }
     }
-    // perform a shuffle to extract the kD vector
+    // Perform a shuffle to extract the kD vector.
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
         extractOp, dstType, srcVector, srcVector,
         rewriter.getI64ArrayAttr(indices));
@@ -215,13 +216,12 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
-
 /// This pattern converts the vector.shuffle operation that works on nD (n > 1)
 /// vectors to a vector.shuffle operation that works on linearized vectors.
-struct LinearizeVectorShffle final
+struct LinearizeVectorShuffle final
     : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorShffle(
+  LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
@@ -249,7 +249,7 @@ struct LinearizeVectorShffle final
     int shuffleSliceLen = 1;
     int rank = shuffleOp.getV1().getType().getRank();
 
-    // if rank > 1, we need to do the shuffle in the granularity of slices
+    // If rank > 1, we need to do the shuffle at the granularity of slices
     // instead of scalars. Size of the slice is equal to the rank-1 innermost
     // dims. Mask of the shuffle op specifies which slice to take from the
     // outermost dim.
@@ -260,10 +260,10 @@ struct LinearizeVectorShffle final
       }
     }
 
-    // for each value in the mask, we generate the indices of the source vectors
-    // that needs to be shuffled to the destination vector. if shuffleSliceLen > 1
-    // we need to shuffle the slices (consecutive shuffleSliceLen number of elements) 
-    // instead of scalars.
+    // For each value in the mask, we generate the indices of the source vectors
+    // that need to be shuffled to the destination vector. If shuffleSliceLen >
+    // 1, we need to shuffle slices (runs of shuffleSliceLen consecutive
+    // elements) instead of scalars.
     auto mask = shuffleOp.getMask();
     auto totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
     llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
@@ -285,9 +285,8 @@ struct LinearizeVectorShffle final
   unsigned targetVectorBitWidth;
 };
 
-
-/// This pattern converts the vector.extract operation to a vector.shuffle operation
-/// that works on a linearized vector.
+/// This pattern converts the vector.extract operation to a vector.shuffle
+/// operation that works on a linearized vector.
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -312,7 +311,7 @@ struct LinearizeVectorExtract final
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
-    // dynamic position is not supported
+    // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(extractOp,
                                          "dynamic position is not supported.");
@@ -320,7 +319,7 @@ struct LinearizeVectorExtract final
     auto shape = extractOp.getVector().getType().getShape();
     auto size = extractOp.getVector().getType().getNumElements();
 
-    // compute linearized offset
+    // Compute linearized offset.
     int64_t linearizedOffset = 0;
     auto offsets = extractOp.getStaticPosition();
     for (auto [i, off] : llvm::enumerate(offsets)) {
@@ -373,6 +372,18 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
                       ? typeConverter.isLegal(op)
                       : true);
         }
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+      typeConverter, patterns.getContext(), targetBitWidth);
+}
+
+void mlir::vector::populateVectorLinearizeToShuffleRewritePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, unsigned int targetBitWidth) {
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
         if (isa<vector::ShuffleOp>(op)) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? (typeConverter.isLegal(op) &&
@@ -384,8 +395,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         }
         return std::nullopt;
       });
-
-  patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
-               LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
+  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
+               LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 00622599910567..e29ea9ce10c68d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -867,6 +867,8 @@ struct TestVectorLinearize final
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
         typeConverter, patterns, target, targetVectorBitwidth);
+    vector::populateVectorLinearizeToShuffleRewritePatterns(
+        typeConverter, patterns, target, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

>From e807c7a09f132d2a6bb6d8205b34d708e7129e1d Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 16:14:04 -0700
Subject: [PATCH 06/11] fix tests

---
 mlir/test/Dialect/Vector/linearize.mlir | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 67f0f667a6b205..b29ceab5783d7a 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -160,13 +160,13 @@ func.func @test_0d_vector() -> vector<f32> {
 func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
   // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
   // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // DEFAULT: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // DEFAULT-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
   // DEFAULT: return %[[RES]] : vector<2x2xf32
 
   // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
   // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+  // BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
   // BW-128: return %[[RES]] : vector<2x2xf32>
 
@@ -183,13 +183,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
   // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // DEFAULT: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // DEFAULT-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
   // DEFAULT: return %[[RES]] : vector<1x4x2xf32>
 
   // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+  // BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
   // BW-128: return %[[RES]] : vector<1x4x2xf32>
 
@@ -207,14 +207,14 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
   // DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
   // DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
   // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
-  // DEFAULT: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // DEFAULT-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // DEFAULT: return %[[RES]] : vector<8x2xf32>
 
   // BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
   // BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
   // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
-  // BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // BW-128: return %[[RES]] : vector<8x2xf32>
 
@@ -230,13 +230,13 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // DEFAULT: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // DEFAULT-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // DEFAULT: return %[[RES]] : vector<8x2xf32>
 
   // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-  // BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+  // BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
   // BW-128: return %[[RES]] : vector<8x2xf32>
 

>From 92374a23a5a5d33530472f8673ba29e49c0205c5 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 11:24:13 -0700
Subject: [PATCH 07/11] fix

---
 .../Vector/Transforms/VectorLinearize.cpp     | 23 ++++++++-----------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 08f6f03eb56ac8..90e8444dc4f3a4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -129,10 +129,9 @@ struct LinearizeVectorExtractStridedSlice final
     auto loc = extractOp.getLoc();
     if (!dstType)
       return rewriter.notifyMatchFailure(loc, "cannot convert type.");
-    if (extractOp.getVector().getType().isScalable() ||
-        dstType.cast<VectorType>().isScalable())
-      return rewriter.notifyMatchFailure(loc,
-                                         "scalable vectors are not supported.");
+    assert(!(extractOp.getVector().getType().isScalable() ||
+             dstType.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -235,11 +234,10 @@ struct LinearizeVectorShuffle final
     auto loc = shuffleOp.getLoc();
     if (!dstType)
       return rewriter.notifyMatchFailure(loc, "cannot convert type.");
-    if (shuffleOp.getV1VectorType().isScalable() ||
-        shuffleOp.getV2VectorType().isScalable() ||
-        dstType.cast<VectorType>().isScalable())
-      return rewriter.notifyMatchFailure(loc,
-                                         "scalable vectors are not supported.");
+    assert(!(shuffleOp.getV1VectorType().isScalable() ||
+             shuffleOp.getV2VectorType().isScalable() ||
+             dstType.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -303,10 +301,9 @@ struct LinearizeVectorExtract final
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
 
-    if (extractOp.getVector().getType().isScalable() ||
-        dstTy.cast<VectorType>().isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
+    assert(!(extractOp.getVector().getType().isScalable() ||
+             dstTy.cast<VectorType>().isScalable()) &&
+           "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");

>From c45c5076012ff13a3dd4920005f0a7d204666006 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 11:59:37 -0700
Subject: [PATCH 08/11] add comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 31 +++++++++++++++++--
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 90e8444dc4f3a4..b71a5c4b418f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -110,8 +110,17 @@ struct LinearizeVectorizable final
   unsigned targetVectorBitWidth;
 };
 
-/// This pattern converts the vector.extract_strided_slice operation to a
-/// vector.shuffle operation that works on a linearized vector.
+/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
+/// on a linearized vector.
+/// Following,
+///   vector.extract_strided_slice %source
+///         { offsets = [..], strides = [..], sizes = [..] }
+/// is converted to :
+///   %source_1d = vector.shape_cast %source
+///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the
+/// extraction.
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -150,7 +159,7 @@ struct LinearizeVectorExtractStridedSlice final
     //  %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
     //  2],
     //     strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
-    // here, extraction granularity is 8.
+    // Here, extraction granularity is 8.
     int64_t extractSliceLen = 1;
     auto n = extractOp.getSourceVectorType().getRank();
     int64_t k = (int64_t)offsets.size();
@@ -217,6 +226,15 @@ struct LinearizeVectorExtractStridedSlice final
 
 /// This pattern converts the vector.shuffle operation that works on nD (n > 1)
 /// vectors to a vector.shuffle operation that works on linearized vectors.
+/// Following,
+///   vector.shuffle %v1, %v2 [ shuffle_indices ]
+/// is converted to :
+///   %v1_1d = vector.shape_cast %v1
+///   %v2_1d = vector.shape_cast %v2
+///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
+/// of the original shuffle operation.
 struct LinearizeVectorShuffle final
     : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -285,6 +303,13 @@ struct LinearizeVectorShuffle final
 
 /// This pattern converts the vector.extract operation to a vector.shuffle
 /// operation that works on a linearized vector.
+/// Following,
+///   vector.extract %source [ position ]
+/// is converted to :
+///   %source_1d = vector.shape_cast %source
+///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original extract.
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;

>From 1ac90fc685454a86449557aabbf76355d7d8589b Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 13:07:08 -0700
Subject: [PATCH 09/11] fix

---
 .../Vector/Transforms/VectorLinearize.cpp     | 22 +++++++++----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b71a5c4b418f38..55e917821007b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
@@ -404,18 +405,15 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 void mlir::vector::populateVectorLinearizeToShuffleRewritePatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned int targetBitWidth) {
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (isa<vector::ShuffleOp>(op)) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? (typeConverter.isLegal(op) &&
-                         op->getResult(0)
-                                 .getType()
-                                 .cast<mlir::VectorType>()
-                                 .getRank() == 1)
-                      : true);
-        }
-        return std::nullopt;
+  target.addDynamicallyLegalOp<vector::ShuffleOp>(
+      [=](vector::ShuffleOp shuffleOp) -> bool {
+        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+                   ? (typeConverter.isLegal(shuffleOp) &&
+                      shuffleOp.getResult()
+                              .getType()
+                              .cast<mlir::VectorType>()
+                              .getRank() == 1)
+                   : true;
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorExtractStridedSlice>(

>From 0b96134b1d7ab249879ebb8b472e123fbb8183f4 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 13:51:42 -0700
Subject: [PATCH 10/11] fix

---
 .../Vector/Transforms/VectorLinearize.cpp     | 47 +++++++++----------
 1 file changed, 21 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 55e917821007b2..b61416e5a8fd7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -136,9 +137,6 @@ struct LinearizeVectorExtractStridedSlice final
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = getTypeConverter()->convertType(extractOp.getType());
-    auto loc = extractOp.getLoc();
-    if (!dstType)
-      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
     assert(!(extractOp.getVector().getType().isScalable() ||
              dstType.cast<VectorType>().isScalable()) &&
            "scalable vectors are not supported.");
@@ -146,9 +144,9 @@ struct LinearizeVectorExtractStridedSlice final
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
-    auto offsets = extractOp.getOffsets().getValue();
-    auto sizes = extractOp.getSizes().getValue();
-    auto strides = extractOp.getStrides().getValue();
+    auto offsets = extractOp.getOffsets();
+    auto sizes = extractOp.getSizes();
+    auto strides = extractOp.getStrides();
     if (!isConstantIntValue(strides[0], 1))
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
@@ -156,32 +154,35 @@ struct LinearizeVectorExtractStridedSlice final
     // If kD offsets are specified for nd source vector (n > k), the granularity
     // of the extraction is greater than 1. In this case last (n-k) dimensions
     // form the extraction granularity.
-    // example :
-    //  %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
-    //  2],
-    //     strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // Example :
+    //  vector.extract_strided_slice %src {
+    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+    //      vector<4x8x8xf32> to vector<2x2x8xf32>
     // Here, extraction granularity is 8.
-    int64_t extractSliceLen = 1;
+    int64_t extractGranularitySize = 1;
     auto n = extractOp.getSourceVectorType().getRank();
     int64_t k = (int64_t)offsets.size();
     if (n > k) {
       for (unsigned i = 0; i < n - k; i++) {
-        extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
+        extractGranularitySize *=
+            extractOp.getSourceVectorType().getShape()[i + k];
       }
     }
     // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
-    for (auto size : sizes) {
+    llvm::for_each(sizes, [&](Attribute size) {
       nExtractedSlices *= size.cast<IntegerAttr>().getInt();
-    }
+    });
     // Compute the strides of the source vector considering first k dimensions.
-    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractGranularitySize);
     for (int i = k - 2; i >= 0; --i) {
       sourceStrides[i] = sourceStrides[i + 1] *
                          extractOp.getSourceVectorType().getShape()[i + 1];
     }
-    // Final shuffle indices has nExtractedElems * extractSliceLen elements.
-    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
+    // Final shuffle indices have nExtractedSlices * extractGranularitySize
+    // elements.
+    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
+                                          extractGranularitySize);
     // Compute the strides of the extracted kD vector.
     llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
     // Compute extractedStrides.
@@ -209,9 +210,9 @@ struct LinearizeVectorExtractStridedSlice final
             sourceStrides[j];
       }
       // Fill the indices array form linearizedIndex to linearizedIndex +
-      // sliceLen.
-      for (int64_t j = 0; j < extractSliceLen; ++j) {
-        indices[i * extractSliceLen + j] = linearizedIndex + j;
+      // extractGranularitySize.
+      for (int64_t j = 0; j < extractGranularitySize; ++j) {
+        indices[i * extractGranularitySize + j] = linearizedIndex + j;
       }
     }
     // Perform a shuffle to extract the kD vector.
@@ -250,9 +251,6 @@ struct LinearizeVectorShuffle final
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
-    auto loc = shuffleOp.getLoc();
-    if (!dstType)
-      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
     assert(!(shuffleOp.getV1VectorType().isScalable() ||
              shuffleOp.getV2VectorType().isScalable() ||
              dstType.cast<VectorType>().isScalable()) &&
@@ -324,9 +322,6 @@ struct LinearizeVectorExtract final
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstTy = getTypeConverter()->convertType(extractOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
-
     assert(!(extractOp.getVector().getType().isScalable() ||
              dstTy.cast<VectorType>().isScalable()) &&
            "scalable vectors are not supported.");

>From 7046090e5462d342140cf54c9328a63d1fcf2063 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
 <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 14:02:04 -0700
Subject: [PATCH 11/11] fix comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b61416e5a8fd7a..be5b1409aca24f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -226,8 +226,8 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
-/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
-/// vectors to a vector.shuffle operation that works on linearized vectors.
+/// This pattern converts the ShuffleOp that works on nD (n > 1)
+/// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
 ///   vector.shuffle %v1, %v2 [ shuffle_indices ]
 /// is converted to :
@@ -300,8 +300,8 @@ struct LinearizeVectorShuffle final
   unsigned targetVectorBitWidth;
 };
 
-/// This pattern converts the vector.extract operation to a vector.shuffle
-/// operation that works on a linearized vector.
+/// This pattern converts the ExtractOp to a ShuffleOp that works on a
+/// linearized vector.
 /// Following,
 ///   vector.extract %source [ position ]
 /// is converted to :



More information about the Mlir-commits mailing list