[Mlir-commits] [mlir] [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (PR #88204)
Charitha Saumya
llvmlistbot at llvm.org
Wed Apr 17 10:20:33 PDT 2024
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/88204
>From dc63b10f878bf2609bd04cc7668b238939969282 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 14:04:04 -0700
Subject: [PATCH 01/14] add linearize patterns for Extract,
ExtractStridedSlice, Shuffle VectorOps
---
.../Vector/Transforms/VectorLinearize.cpp | 249 +++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 80 ++++++
2 files changed, 328 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b59e9062e5a08e..257c940e5ed93c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
using namespace mlir;
@@ -103,6 +105,234 @@ struct LinearizeVectorizable final
return success();
}
+private:
+ unsigned targetVectorBitWidth;
+};
+
+/// Converts a vector.extract_strided_slice on an n-D vector into a
+/// vector.shuffle on the linearized (1-D) source vector.
+///
+/// If only k < n offsets are given, the trailing (n - k) source dimensions are
+/// extracted whole, so every extracted "slice" is a run of contiguous elements
+/// whose length is the product of those trailing dimensions. Example:
+///   %0 = vector.extract_strided_slice %src
+///          {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+///          : vector<4x8x8xf32> to vector<2x2x8xf32>
+/// extracts 2 * 2 = 4 slices of 8 contiguous elements each.
+struct LinearizeVectorExtractStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtractStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
+    Location loc = extractOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    llvm::ArrayRef<Attribute> offsets = extractOp.getOffsets().getValue();
+    llvm::ArrayRef<Attribute> sizes = extractOp.getSizes().getValue();
+    llvm::ArrayRef<Attribute> strides = extractOp.getStrides().getValue();
+
+    // The shuffle below copies contiguous runs, so *every* stride must be 1
+    // (not just the first). all_of is also safe when no strides are given.
+    if (!llvm::all_of(strides, [](Attribute stride) {
+          return isConstantIntValue(stride, 1);
+        }))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
+
+    Value srcVector = adaptor.getVector();
+    llvm::ArrayRef<int64_t> srcShape =
+        extractOp.getSourceVectorType().getShape();
+    int64_t n = extractOp.getSourceVectorType().getRank();
+    int64_t k = static_cast<int64_t>(offsets.size());
+
+    // Number of contiguous elements per extracted slice: the product of the
+    // trailing (n - k) source dimensions (1 when k == n).
+    int64_t extractSliceLen = 1;
+    for (int64_t d = k; d < n; ++d)
+      extractSliceLen *= srcShape[d];
+
+    // Total number of extracted slices: the product of the k sizes.
+    int64_t nExtractedSlices = 1;
+    for (Attribute size : sizes)
+      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    // Strides (in elements) of the first k dimensions of the linearized source.
+    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+    for (int64_t i = k - 2; i >= 0; --i)
+      sourceStrides[i] = sourceStrides[i + 1] * srcShape[i + 1];
+
+    // Strides (in slices) of the k-D grid of extracted slices.
+    llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
+    for (int64_t i = k - 2; i >= 0; --i)
+      extractedStrides[i] =
+          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+
+    // The final mask has nExtractedSlices * extractSliceLen entries. For each
+    // extracted slice, decompose its flat position into k-D coordinates, shift
+    // them by the offsets to get the slice start in the linearized source, and
+    // fill the run [start, start + extractSliceLen).
+    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      int64_t remainder = i;
+      int64_t linearizedIndex = 0;
+      for (int64_t j = 0; j < k; ++j) {
+        int64_t coord = remainder / extractedStrides[j];
+        remainder -= coord * extractedStrides[j];
+        linearizedIndex +=
+            (cast<IntegerAttr>(offsets[j]).getInt() + coord) * sourceStrides[j];
+      }
+      auto sliceBegin = indices.begin() + i * extractSliceLen;
+      std::iota(sliceBegin, sliceBegin + extractSliceLen, linearizedIndex);
+    }
+
+    // Shuffle the linearized source with itself to extract the slices.
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Converts a vector.shuffle on n-D (n > 1) operands into a vector.shuffle on
+/// the linearized (1-D) operands.
+///
+/// The mask of an n-D shuffle selects whole slices along the outermost
+/// dimension of the two operands (indices >= dim0 of v1 refer to v2). After
+/// linearization each such slice is a run of contiguous elements whose length
+/// is the product of the remaining (n - 1) dimensions, so every mask entry
+/// expands into that many consecutive indices.
+struct LinearizeVectorShffle final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorShffle(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    Location loc = shuffleOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+
+    if (shuffleOp.getV1VectorType().isScalable() ||
+        shuffleOp.getV2VectorType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    Value vec1 = adaptor.getV1();
+    Value vec2 = adaptor.getV2();
+
+    // Number of contiguous elements per shuffled slice: the product of all but
+    // the outermost dimension (1 for rank <= 1). Kept as int64_t because
+    // vector dimensions are int64_t; an `int` could truncate.
+    int64_t shuffleSliceLen = 1;
+    llvm::ArrayRef<int64_t> shape = shuffleOp.getV1VectorType().getShape();
+    for (size_t d = 1; d < shape.size(); ++d)
+      shuffleSliceLen *= shape[d];
+
+    // Expand each mask entry v into the run
+    // [v * shuffleSliceLen, (v + 1) * shuffleSliceLen) of linearized indices.
+    auto mask = shuffleOp.getMask();
+    llvm::SmallVector<int64_t, 2> indices(mask.size() * shuffleSliceLen);
+    for (auto [i, value] :
+         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
+      int64_t v = value.getZExtValue();
+      auto sliceBegin = indices.begin() + i * shuffleSliceLen;
+      std::iota(sliceBegin, sliceBegin + shuffleSliceLen, v * shuffleSliceLen);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Converts a vector.extract with static positions from an n-D vector into a
+/// vector.shuffle on the linearized (1-D) source. In row-major order the
+/// extracted sub-vector occupies one contiguous run of the linearized source,
+/// so the mask is [linearizedOffset, linearizedOffset + resultNumElements).
+struct LinearizeVectorExtract final
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtract(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
+
+    // Extracting with a full position yields a scalar result. There is no
+    // sub-vector to shuffle out, and an unconditional cast<VectorType> on the
+    // converted result type would assert.
+    auto dstVecTy = dyn_cast<VectorType>(dstTy);
+    if (!dstVecTy)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
+    if (extractOp.getVector().getType().isScalable() ||
+        dstVecTy.isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    // Dynamic position is not supported.
+    if (extractOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "dynamic position is not supported.");
+
+    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
+    int64_t size = extractOp.getVector().getType().getNumElements();
+
+    // Compute the linearized offset. After dividing out position i's
+    // dimension, `size` is both the stride of position i and, once the loop
+    // ends, the element count of the extracted sub-vector.
+    int64_t linearizedOffset = 0;
+    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
+    for (auto [i, off] : llvm::enumerate(offsets)) {
+      size /= shape[i];
+      linearizedOffset += off * size;
+    }
+
+    llvm::SmallVector<int64_t, 2> indices(size);
+    std::iota(indices.begin(), indices.end(), linearizedOffset);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
+        rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
private:
unsigned targetVectorBitWidth;
};
@@ -139,9 +369,26 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
? typeConverter.isLegal(op)
: true);
}
+ if (isa<vector::ShuffleOp>(op)) {
+ return (isLessThanTargetBitWidth(op, targetBitWidth)
+ ? (typeConverter.isLegal(op) &&
+ op->getResult(0)
+ .getType()
+ .cast<mlir::VectorType>()
+ .getRank() == 1)
+ : true);
+ }
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ // target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>(
+ // [=](mlir::Operation *op) {
+ // return op->getResult(0).getType().cast<mlir::VectorType>().getRank()
+ // ==
+ // 1;
+ // });
+
+ patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
+ LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 212541c79565b6..d4215a88977eb7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -164,3 +164,83 @@ func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]x
return %2 : vector<2x[2]xf32>
}
+
+// -----
+// Extracts rows [0, 2) and columns [4, 6) of a 4x8 vector. In the linearized
+// 32-element vector row r starts at element r * 8, so the expected shuffle
+// mask is [4, 5, 12, 13].
+// ALL-LABEL: test_extract_strided_slice_1
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+ // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // DEFAULT: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+ // DEFAULT: return %[[RES]] : vector<2x2xf32>
+
+ // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+ // BW-128: return %[[RES]] : vector<2x2xf32>
+ %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+ : vector<4x8xf32> to vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+// Only 2 of the 3 dims carry offsets, so each extracted slice is a whole
+// innermost row of 2 elements. The first slice starts at [1, 2, 0], i.e.
+// linear index 1 * 16 + 2 * 2 = 20, and four consecutive 2-element slices
+// give the contiguous mask 20..27.
+// ALL-LABEL: test_extract_strided_slice_2
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+ // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // DEFAULT: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+ // DEFAULT: return %[[RES]] : vector<1x4x2xf32>
+
+ // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+ // BW-128: return %[[RES]] : vector<1x4x2xf32>
+ %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+ : vector<2x8x2xf32> to vector<1x4x2xf32>
+ return %0 : vector<1x4x2xf32>
+}
+
+// -----
+// A 2-D shuffle selects whole rows of 2 elements. Each mask entry m expands to
+// [2m, 2m + 1] over the two linearized operands, where %arg1's elements are
+// numbered from 8 (after the 8 elements of %arg0).
+// ALL-LABEL: test_vector_shuffle
+// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
+ // DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+ // DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+ // DEFAULT: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+ // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+ // BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+ // BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+ // BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+ // BW-128: return %[[RES]] : vector<8x2xf32>
+ %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+ return %0 : vector<8x2xf32>
+}
+
+// -----
+// Extracting position [1] of a 2x8x2 vector takes the second 8x2 plane, which
+// is the contiguous run of linearized elements 16..31.
+// ALL-LABEL: test_vector_extract
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
+func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
+ // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // DEFAULT: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+ // DEFAULT: return %[[RES]] : vector<8x2xf32>
+
+ // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+ // BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+ // BW-128: return %[[RES]] : vector<8x2xf32>
+ %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+ return %0 : vector<8x2xf32>
+}
>From de748c0f93e1ead19c5cc402940c9df8ab180b2d Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 14:59:06 -0700
Subject: [PATCH 02/14] remove comments
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 7 -------
1 file changed, 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 257c940e5ed93c..e5157abd245b5d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -381,13 +381,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- // target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>(
- // [=](mlir::Operation *op) {
- // return op->getResult(0).getType().cast<mlir::VectorType>().getRank()
- // ==
- // 1;
- // });
-
patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
>From 962243c475e9f4b2b4fc1231edf92bc06ec12767 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Apr 2024 15:21:22 -0700
Subject: [PATCH 03/14] fix test
---
mlir/test/Dialect/Vector/linearize.mlir | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 88d011e7c8594c..67f0f667a6b205 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -169,6 +169,9 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
// BW-128: return %[[RES]] : vector<2x2xf32>
+
+ // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+ // BW-0: return %[[RES]] : vector<2x2xf32>
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
: vector<4x8xf32> to vector<2x2xf32>
return %0 : vector<2x2xf32>
@@ -189,6 +192,9 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
// BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
// BW-128: return %[[RES]] : vector<1x4x2xf32>
+
+ // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
+ // BW-0: return %[[RES]] : vector<1x4x2xf32>
%0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
: vector<2x8x2xf32> to vector<1x4x2xf32>
return %0 : vector<1x4x2xf32>
@@ -211,6 +217,9 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
// BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// BW-128: return %[[RES]] : vector<8x2xf32>
+
+ // BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+ // BW-0: return %[[RES]] : vector<8x2xf32>
%0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
return %0 : vector<8x2xf32>
}
@@ -230,6 +239,9 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
// BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// BW-128: return %[[RES]] : vector<8x2xf32>
+
+ // BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
+ // BW-0: return %[[RES]] : vector<8x2xf32>
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
return %0 : vector<8x2xf32>
}
>From e20be009e4b7e9abdc6acd09de44d066abcd4a30 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 12:22:24 -0700
Subject: [PATCH 04/14] address comments
---
.../Vector/Transforms/VectorLinearize.cpp | 34 +++++++++++--------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index e5157abd245b5d..c85f8ecf825090 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,6 +109,9 @@ struct LinearizeVectorizable final
unsigned targetVectorBitWidth;
};
+
+/// This pattern converts the vector.extract_strided_slice operation to a
+/// vector.shuffle operation that works on a linearized vector.
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -137,18 +140,16 @@ struct LinearizeVectorExtractStridedSlice final
auto offsets = extractOp.getOffsets().getValue();
auto sizes = extractOp.getSizes().getValue();
auto strides = extractOp.getStrides().getValue();
-
if (!isConstantIntValue(strides[0], 1))
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
-
Value srcVector = adaptor.getVector();
-
// if kD offsets are specified for nd source vector (n > k), the granularity
// of the extraction is greater than 1. In this case last (n-k) dimensions
- // form the extraction granularity. example : %0 =
- // vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
- // strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
+ // form the extraction granularity.
+ // example :
+ // %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
+ // strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
// here, extraction granularity is 8.
int64_t extractSliceLen = 1;
auto n = extractOp.getSourceVectorType().getRank();
@@ -158,13 +159,11 @@ struct LinearizeVectorExtractStridedSlice final
extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
}
}
-
// get total number of extracted slices
int64_t nExtractedSlices = 1;
for (auto size : sizes) {
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
}
-
// compute the strides of the source vector considering first k dimensions
llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
for (int i = k - 2; i >= 0; --i) {
@@ -209,7 +208,6 @@ struct LinearizeVectorExtractStridedSlice final
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstType, srcVector, srcVector,
rewriter.getI64ArrayAttr(indices));
-
return success();
}
@@ -217,6 +215,9 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
+
+/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
+/// vectors to a vector.shuffle operation that works on linearized vectors.
struct LinearizeVectorShffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
@@ -234,7 +235,6 @@ struct LinearizeVectorShffle final
auto loc = shuffleOp.getLoc();
if (!dstType)
return rewriter.notifyMatchFailure(loc, "cannot convert type.");
-
if (shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
dstType.cast<VectorType>().isScalable())
@@ -246,7 +246,6 @@ struct LinearizeVectorShffle final
auto vec1 = adaptor.getV1();
auto vec2 = adaptor.getV2();
-
int shuffleSliceLen = 1;
int rank = shuffleOp.getV1().getType().getRank();
@@ -261,10 +260,13 @@ struct LinearizeVectorShffle final
}
}
+ // for each value in the mask, we generate the indices of the source vectors
+ // that needs to be shuffled to the destination vector. if shuffleSliceLen > 1
+ // we need to shuffle the slices (consecutive shuffleSliceLen number of elements)
+ // instead of scalars.
auto mask = shuffleOp.getMask();
- auto totalSize = mask.size() * shuffleSliceLen;
-
- llvm::SmallVector<int64_t, 2> indices(totalSize);
+ auto totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
+ llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
for (auto [i, value] :
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
@@ -276,7 +278,6 @@ struct LinearizeVectorShffle final
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
-
return success();
}
@@ -284,6 +285,9 @@ struct LinearizeVectorShffle final
unsigned targetVectorBitWidth;
};
+
+/// This pattern converts the vector.extract operation to a vector.shuffle operation
+/// that works on a linearized vector.
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
>From 0d9406df3f522c6a57fae495d78c798baed3c419 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 16:02:20 -0700
Subject: [PATCH 05/14] address comments
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 ++
.../Vector/Transforms/VectorLinearize.cpp | 80 +++++++++++--------
.../Dialect/Vector/TestVectorTransforms.cpp | 2 +
3 files changed, 53 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1a..d630a3562beb94 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -389,6 +389,12 @@ void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
+/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
+/// vector shuffle operations.
+///
+/// NOTE(review): the parameters mirror those of
+/// populateVectorLinearizeTypeConversionsAndLegality. Presumably
+/// `typeConverter` maps N-D vector types to 1-D, `target` receives the
+/// matching legality rules, and ops are only rewritten when their size is
+/// below `targetBitWidth`. Confirm against the definition.
+void populateVectorLinearizeToShuffleRewritePatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, unsigned targetBitWidth);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c85f8ecf825090..08f6f03eb56ac8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
#include <numeric>
using namespace mlir;
@@ -109,7 +110,6 @@ struct LinearizeVectorizable final
unsigned targetVectorBitWidth;
};
-
/// This pattern converts the vector.extract_strided_slice operation to a
/// vector.shuffle operation that works on a linearized vector.
struct LinearizeVectorExtractStridedSlice final
@@ -144,67 +144,68 @@ struct LinearizeVectorExtractStridedSlice final
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
Value srcVector = adaptor.getVector();
- // if kD offsets are specified for nd source vector (n > k), the granularity
+ // If kD offsets are specified for nd source vector (n > k), the granularity
// of the extraction is greater than 1. In this case last (n-k) dimensions
- // form the extraction granularity.
- // example :
- // %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
+ // form the extraction granularity.
+ // example :
+ // %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
+ // 2],
// strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
// here, extraction granularity is 8.
int64_t extractSliceLen = 1;
auto n = extractOp.getSourceVectorType().getRank();
- auto k = (int64_t)offsets.size();
+ int64_t k = (int64_t)offsets.size();
if (n > k) {
for (unsigned i = 0; i < n - k; i++) {
extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
}
}
- // get total number of extracted slices
+ // Get total number of extracted slices.
int64_t nExtractedSlices = 1;
for (auto size : sizes) {
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
}
- // compute the strides of the source vector considering first k dimensions
+ // Compute the strides of the source vector considering first k dimensions.
llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
for (int i = k - 2; i >= 0; --i) {
sourceStrides[i] = sourceStrides[i + 1] *
extractOp.getSourceVectorType().getShape()[i + 1];
}
- // final shuffle indices has nExtractedElems * extractSliceLen elements
+ // Final shuffle indices has nExtractedElems * extractSliceLen elements.
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
- // compute the strides of the extracted kD vector
+ // Compute the strides of the extracted kD vector.
llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
- // compute extractedStrides
+ // Compute extractedStrides.
for (int i = k - 2; i >= 0; --i) {
extractedStrides[i] =
extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
}
- // iterate over all extracted slices from 0 to nExtractedElems-1
+ // Iterate over all extracted slices from 0 to nExtractedElems-1
// and compute the multi-dimensional index and the corresponding linearized
- // index within the source vector
+ // index within the source vector.
for (int64_t i = 0; i < nExtractedSlices; ++i) {
int64_t index = i;
- // compute the corresponding multi-dimensional index
+ // Compute the corresponding multi-dimensional index.
llvm::SmallVector<int64_t, 4> multiDimIndex(k, 0);
for (int64_t j = 0; j < k; ++j) {
multiDimIndex[j] = (index / extractedStrides[j]);
index -= multiDimIndex[j] * extractedStrides[j];
}
- // compute the corresponding linearized index in the source vector
- // i.e. shift the multiDimIndex by the offsets
+ // Compute the corresponding linearized index in the source vector
+ // i.e. shift the multiDimIndex by the offsets.
int64_t linearizedIndex = 0;
for (int64_t j = 0; j < k; ++j) {
linearizedIndex +=
(offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
sourceStrides[j];
}
- // fill the indices array form linearizedIndex to linearizedIndex +
- // sliceLen
+ // Fill the indices array form linearizedIndex to linearizedIndex +
+ // sliceLen.
for (int64_t j = 0; j < extractSliceLen; ++j) {
indices[i * extractSliceLen + j] = linearizedIndex + j;
}
}
- // perform a shuffle to extract the kD vector
+ // Perform a shuffle to extract the kD vector.
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstType, srcVector, srcVector,
rewriter.getI64ArrayAttr(indices));
@@ -215,13 +216,12 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
-
/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
/// vectors to a vector.shuffle operation that works on linearized vectors.
-struct LinearizeVectorShffle final
+struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorShffle(
+ LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
@@ -249,7 +249,7 @@ struct LinearizeVectorShffle final
int shuffleSliceLen = 1;
int rank = shuffleOp.getV1().getType().getRank();
- // if rank > 1, we need to do the shuffle in the granularity of slices
+ // If rank > 1, we need to do the shuffle in the granularity of slices
// instead of scalars. Size of the slice is equal to the rank-1 innermost
// dims. Mask of the shuffle op specifies which slice to take from the
// outermost dim.
@@ -260,10 +260,10 @@ struct LinearizeVectorShffle final
}
}
- // for each value in the mask, we generate the indices of the source vectors
- // that needs to be shuffled to the destination vector. if shuffleSliceLen > 1
- // we need to shuffle the slices (consecutive shuffleSliceLen number of elements)
- // instead of scalars.
+ // For each value in the mask, we generate the indices of the source vectors
+ // that needs to be shuffled to the destination vector. If shuffleSliceLen >
+ // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
+ // elements) instead of scalars.
auto mask = shuffleOp.getMask();
auto totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
@@ -285,9 +285,8 @@ struct LinearizeVectorShffle final
unsigned targetVectorBitWidth;
};
-
-/// This pattern converts the vector.extract operation to a vector.shuffle operation
-/// that works on a linearized vector.
+/// This pattern converts the vector.extract operation to a vector.shuffle
+/// operation that works on a linearized vector.
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -312,7 +311,7 @@ struct LinearizeVectorExtract final
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
- // dynamic position is not supported
+ // Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
return rewriter.notifyMatchFailure(extractOp,
"dynamic position is not supported.");
@@ -320,7 +319,7 @@ struct LinearizeVectorExtract final
auto shape = extractOp.getVector().getType().getShape();
auto size = extractOp.getVector().getType().getNumElements();
- // compute linearized offset
+ // Compute linearized offset.
int64_t linearizedOffset = 0;
auto offsets = extractOp.getStaticPosition();
for (auto [i, off] : llvm::enumerate(offsets)) {
@@ -373,6 +372,18 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
? typeConverter.isLegal(op)
: true);
}
+ return std::nullopt;
+ });
+
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ typeConverter, patterns.getContext(), targetBitWidth);
+}
+
+void mlir::vector::populateVectorLinearizeToShuffleRewritePatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, unsigned int targetBitWidth) {
+ target.markUnknownOpDynamicallyLegal(
+ [=](Operation *op) -> std::optional<bool> {
if (isa<vector::ShuffleOp>(op)) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? (typeConverter.isLegal(op) &&
@@ -384,8 +395,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
}
return std::nullopt;
});
-
- patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorShffle,
- LinearizeVectorExtract, LinearizeVectorExtractStridedSlice>(
+ patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
+ LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 00622599910567..e29ea9ce10c68d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -867,6 +867,8 @@ struct TestVectorLinearize final
vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter, patterns, target, targetVectorBitwidth);
+ vector::populateVectorLinearizeToShuffleRewritePatterns(
+ typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
>From e807c7a09f132d2a6bb6d8205b34d708e7129e1d Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 12 Apr 2024 16:14:04 -0700
Subject: [PATCH 06/14] fix tests
---
mlir/test/Dialect/Vector/linearize.mlir | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 67f0f667a6b205..b29ceab5783d7a 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -160,13 +160,13 @@ func.func @test_0d_vector() -> vector<f32> {
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // DEFAULT: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+ // DEFAULT-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
// DEFAULT: return %[[RES]] : vector<2x2xf32
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // BW-128: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+ // BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
// BW-128: return %[[RES]] : vector<2x2xf32>
@@ -183,13 +183,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // DEFAULT: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+ // DEFAULT-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
// DEFAULT: return %[[RES]] : vector<1x4x2xf32>
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // BW-128: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+ // BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
// BW-128: return %[[RES]] : vector<1x4x2xf32>
@@ -207,14 +207,14 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
// DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
// DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
- // DEFAULT: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // DEFAULT-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// DEFAULT: return %[[RES]] : vector<8x2xf32>
// BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
// BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
- // BW-128: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// BW-128: return %[[RES]] : vector<8x2xf32>
@@ -230,13 +230,13 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // DEFAULT: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+ // DEFAULT-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// DEFAULT: return %[[RES]] : vector<8x2xf32>
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
- // BW-128: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
+ // BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// BW-128: return %[[RES]] : vector<8x2xf32>
>From 92374a23a5a5d33530472f8673ba29e49c0205c5 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 11:24:13 -0700
Subject: [PATCH 07/14] fix
---
.../Vector/Transforms/VectorLinearize.cpp | 23 ++++++++-----------
1 file changed, 10 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 08f6f03eb56ac8..90e8444dc4f3a4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -129,10 +129,9 @@ struct LinearizeVectorExtractStridedSlice final
auto loc = extractOp.getLoc();
if (!dstType)
return rewriter.notifyMatchFailure(loc, "cannot convert type.");
- if (extractOp.getVector().getType().isScalable() ||
- dstType.cast<VectorType>().isScalable())
- return rewriter.notifyMatchFailure(loc,
- "scalable vectors are not supported.");
+ assert(!(extractOp.getVector().getType().isScalable() ||
+ dstType.cast<VectorType>().isScalable()) &&
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -235,11 +234,10 @@ struct LinearizeVectorShuffle final
auto loc = shuffleOp.getLoc();
if (!dstType)
return rewriter.notifyMatchFailure(loc, "cannot convert type.");
- if (shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- dstType.cast<VectorType>().isScalable())
- return rewriter.notifyMatchFailure(loc,
- "scalable vectors are not supported.");
+ assert(!(shuffleOp.getV1VectorType().isScalable() ||
+ shuffleOp.getV2VectorType().isScalable() ||
+ dstType.cast<VectorType>().isScalable()) &&
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -303,10 +301,9 @@ struct LinearizeVectorExtract final
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
- if (extractOp.getVector().getType().isScalable() ||
- dstTy.cast<VectorType>().isScalable())
- return rewriter.notifyMatchFailure(extractOp,
- "scalable vectors are not supported.");
+ assert(!(extractOp.getVector().getType().isScalable() ||
+ dstTy.cast<VectorType>().isScalable()) &&
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
>From c45c5076012ff13a3dd4920005f0a7d204666006 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 11:59:37 -0700
Subject: [PATCH 08/14] add comments
---
.../Vector/Transforms/VectorLinearize.cpp | 31 +++++++++++++++++--
1 file changed, 28 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 90e8444dc4f3a4..b71a5c4b418f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -110,8 +110,17 @@ struct LinearizeVectorizable final
unsigned targetVectorBitWidth;
};
-/// This pattern converts the vector.extract_strided_slice operation to a
-/// vector.shuffle operation that works on a linearized vector.
+/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
+/// on a linearized vector.
+/// Following,
+/// vector.extract_strided_slice %source
+/// { offsets = [..], strides = [..], sizes = [..] }
+/// is converted to :
+/// %source_1d = vector.shape_cast %source
+/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the
+/// extraction.
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -150,7 +159,7 @@ struct LinearizeVectorExtractStridedSlice final
// %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
// 2],
// strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
- // here, extraction granularity is 8.
+ // Here, extraction granularity is 8.
int64_t extractSliceLen = 1;
auto n = extractOp.getSourceVectorType().getRank();
int64_t k = (int64_t)offsets.size();
@@ -217,6 +226,15 @@ struct LinearizeVectorExtractStridedSlice final
/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
/// vectors to a vector.shuffle operation that works on linearized vectors.
+/// Following,
+/// vector.shuffle %v1, %v2 [ shuffle_indices ]
+/// is converted to :
+/// %v1_1d = vector.shape_cast %v1
+/// %v2_1d = vector.shape_cast %v2
+/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d
+// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
+/// of the original shuffle operation.
struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
@@ -285,6 +303,13 @@ struct LinearizeVectorShuffle final
/// This pattern converts the vector.extract operation to a vector.shuffle
/// operation that works on a linearized vector.
+/// Following,
+/// vector.extract %source [ position ]
+/// is converted to :
+/// %source_1d = vector.shape_cast %source
+/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original extract.
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
>From 1ac90fc685454a86449557aabbf76355d7d8589b Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 13:07:08 -0700
Subject: [PATCH 09/14] fix
---
.../Vector/Transforms/VectorLinearize.cpp | 22 +++++++++----------
1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b71a5c4b418f38..55e917821007b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
@@ -404,18 +405,15 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
void mlir::vector::populateVectorLinearizeToShuffleRewritePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned int targetBitWidth) {
- target.markUnknownOpDynamicallyLegal(
- [=](Operation *op) -> std::optional<bool> {
- if (isa<vector::ShuffleOp>(op)) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
- ? (typeConverter.isLegal(op) &&
- op->getResult(0)
- .getType()
- .cast<mlir::VectorType>()
- .getRank() == 1)
- : true);
- }
- return std::nullopt;
+ target.addDynamicallyLegalOp<vector::ShuffleOp>(
+ [=](vector::ShuffleOp shuffleOp) -> bool {
+ return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+ ? (typeConverter.isLegal(shuffleOp) &&
+ shuffleOp.getResult()
+ .getType()
+ .cast<mlir::VectorType>()
+ .getRank() == 1)
+ : true;
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorExtractStridedSlice>(
>From 0b96134b1d7ab249879ebb8b472e123fbb8183f4 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 13:51:42 -0700
Subject: [PATCH 10/14] fix
---
.../Vector/Transforms/VectorLinearize.cpp | 47 +++++++++----------
1 file changed, 21 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 55e917821007b2..b61416e5a8fd7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -136,9 +137,6 @@ struct LinearizeVectorExtractStridedSlice final
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = getTypeConverter()->convertType(extractOp.getType());
- auto loc = extractOp.getLoc();
- if (!dstType)
- return rewriter.notifyMatchFailure(loc, "cannot convert type.");
assert(!(extractOp.getVector().getType().isScalable() ||
dstType.cast<VectorType>().isScalable()) &&
"scalable vectors are not supported.");
@@ -146,9 +144,9 @@ struct LinearizeVectorExtractStridedSlice final
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
- auto offsets = extractOp.getOffsets().getValue();
- auto sizes = extractOp.getSizes().getValue();
- auto strides = extractOp.getStrides().getValue();
+ auto offsets = extractOp.getOffsets();
+ auto sizes = extractOp.getSizes();
+ auto strides = extractOp.getStrides();
if (!isConstantIntValue(strides[0], 1))
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
@@ -156,32 +154,35 @@ struct LinearizeVectorExtractStridedSlice final
// If kD offsets are specified for nd source vector (n > k), the granularity
// of the extraction is greater than 1. In this case last (n-k) dimensions
// form the extraction granularity.
- // example :
- // %0 = vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2,
- // 2],
- // strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
+ // Example :
+ // vector.extract_strided_slice %src {
+ // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+ // vector<4x8x8xf32> to vector<2x2x8xf32>
// Here, extraction granularity is 8.
- int64_t extractSliceLen = 1;
+ int64_t extractGranularitySize = 1;
auto n = extractOp.getSourceVectorType().getRank();
int64_t k = (int64_t)offsets.size();
if (n > k) {
for (unsigned i = 0; i < n - k; i++) {
- extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
+ extractGranularitySize *=
+ extractOp.getSourceVectorType().getShape()[i + k];
}
}
// Get total number of extracted slices.
int64_t nExtractedSlices = 1;
- for (auto size : sizes) {
+ llvm::for_each(sizes, [&](Attribute size) {
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
- }
+ });
// Compute the strides of the source vector considering first k dimensions.
- llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+ llvm::SmallVector<int64_t, 4> sourceStrides(k, extractGranularitySize);
for (int i = k - 2; i >= 0; --i) {
sourceStrides[i] = sourceStrides[i + 1] *
extractOp.getSourceVectorType().getShape()[i + 1];
}
- // Final shuffle indices has nExtractedElems * extractSliceLen elements.
- llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
+ // Final shuffle indices has nExtractedElems * extractGranularitySize
+ // elements.
+ llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
+ extractGranularitySize);
// Compute the strides of the extracted kD vector.
llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
// Compute extractedStrides.
@@ -209,9 +210,9 @@ struct LinearizeVectorExtractStridedSlice final
sourceStrides[j];
}
// Fill the indices array form linearizedIndex to linearizedIndex +
- // sliceLen.
- for (int64_t j = 0; j < extractSliceLen; ++j) {
- indices[i * extractSliceLen + j] = linearizedIndex + j;
+ // extractGranularitySize.
+ for (int64_t j = 0; j < extractGranularitySize; ++j) {
+ indices[i * extractGranularitySize + j] = linearizedIndex + j;
}
}
// Perform a shuffle to extract the kD vector.
@@ -250,9 +251,6 @@ struct LinearizeVectorShuffle final
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
- auto loc = shuffleOp.getLoc();
- if (!dstType)
- return rewriter.notifyMatchFailure(loc, "cannot convert type.");
assert(!(shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
dstType.cast<VectorType>().isScalable()) &&
@@ -324,9 +322,6 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstTy = getTypeConverter()->convertType(extractOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
-
assert(!(extractOp.getVector().getType().isScalable() ||
dstTy.cast<VectorType>().isScalable()) &&
"scalable vectors are not supported.");
>From 7046090e5462d342140cf54c9328a63d1fcf2063 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 14:02:04 -0700
Subject: [PATCH 11/14] fix comments
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b61416e5a8fd7a..be5b1409aca24f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -226,8 +226,8 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
-/// This pattern converts the vector.shuffle operation that works on nD (n > 1)
-/// vectors to a vector.shuffle operation that works on linearized vectors.
+/// This pattern converts the ShuffleOp that works on nD (n > 1)
+/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
/// vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
@@ -300,8 +300,8 @@ struct LinearizeVectorShuffle final
unsigned targetVectorBitWidth;
};
-/// This pattern converts the vector.extract operation to a vector.shuffle
-/// operation that works on a linearized vector.
+/// This pattern converts the ExtractOp to a ShuffleOp that works on a
+/// linearized vector.
/// Following,
/// vector.extract %source [ position ]
/// is converted to :
>From 780896e5bc61e485a6b23d96ea75bf365cd3523b Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 14:32:57 -0700
Subject: [PATCH 12/14] fix comments
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index be5b1409aca24f..1c00d14b3309dc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -300,7 +300,7 @@ struct LinearizeVectorShuffle final
unsigned targetVectorBitWidth;
};
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
+/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
/// vector.extract %source [ position ]
>From ec369719dda443d2190e440f39ff7f6b2faea8b9 Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 16 Apr 2024 15:13:48 -0700
Subject: [PATCH 13/14] fix
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1c00d14b3309dc..0cb7906fb0eb74 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -179,7 +179,7 @@ struct LinearizeVectorExtractStridedSlice final
sourceStrides[i] = sourceStrides[i + 1] *
extractOp.getSourceVectorType().getShape()[i + 1];
}
- // Final shuffle indices has nExtractedElems * extractGranularitySize
+ // Final shuffle indices has nExtractedSlices * extractGranularitySize
// elements.
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
extractGranularitySize);
@@ -190,7 +190,7 @@ struct LinearizeVectorExtractStridedSlice final
extractedStrides[i] =
extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
}
- // Iterate over all extracted slices from 0 to nExtractedElems-1
+ // Iterate over all extracted slices from 0 to nExtractedSlices - 1
// and compute the multi-dimensional index and the corresponding linearized
// index within the source vector.
for (int64_t i = 0; i < nExtractedSlices; ++i) {
>From 32f39da546058ddb2f7a29feb3ad599fcdf7927c Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
<charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 17 Apr 2024 10:12:08 -0700
Subject: [PATCH 14/14] address comments
---
.../Vector/Transforms/VectorRewritePatterns.h | 7 ++--
.../Vector/Transforms/VectorLinearize.cpp | 42 ++++++++++---------
.../Dialect/Vector/TestVectorTransforms.cpp | 2 +-
3 files changed, 27 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index d630a3562beb94..fa2912a3e577d1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -391,9 +391,10 @@ void populateVectorLinearizeTypeConversionsAndLegality(
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
-void populateVectorLinearizeToShuffleRewritePatterns(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns,
+ ConversionTarget &target,
+ unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 0cb7906fb0eb74..01b3491f98baba 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -14,11 +14,13 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <numeric>
@@ -136,7 +138,7 @@ struct LinearizeVectorExtractStridedSlice final
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = getTypeConverter()->convertType(extractOp.getType());
+ Type dstType = getTypeConverter()->convertType(extractOp.getType());
assert(!(extractOp.getVector().getType().isScalable() ||
dstType.cast<VectorType>().isScalable()) &&
"scalable vectors are not supported.");
@@ -144,9 +146,9 @@ struct LinearizeVectorExtractStridedSlice final
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
- auto offsets = extractOp.getOffsets();
- auto sizes = extractOp.getSizes();
- auto strides = extractOp.getStrides();
+ ArrayAttr offsets = extractOp.getOffsets();
+ ArrayAttr sizes = extractOp.getSizes();
+ ArrayAttr strides = extractOp.getStrides();
if (!isConstantIntValue(strides[0], 1))
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
@@ -160,19 +162,19 @@ struct LinearizeVectorExtractStridedSlice final
// vector<4x8x8xf32> to vector<2x2x8xf32>
// Here, extraction granularity is 8.
int64_t extractGranularitySize = 1;
- auto n = extractOp.getSourceVectorType().getRank();
+ int64_t n = extractOp.getSourceVectorType().getRank();
int64_t k = (int64_t)offsets.size();
if (n > k) {
- for (unsigned i = 0; i < n - k; i++) {
+ for (unsigned i = 0; i < n - k; ++i) {
extractGranularitySize *=
extractOp.getSourceVectorType().getShape()[i + k];
}
}
// Get total number of extracted slices.
int64_t nExtractedSlices = 1;
- llvm::for_each(sizes, [&](Attribute size) {
+ for (Attribute size : sizes) {
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
- });
+ }
// Compute the strides of the source vector considering first k dimensions.
llvm::SmallVector<int64_t, 4> sourceStrides(k, extractGranularitySize);
for (int i = k - 2; i >= 0; --i) {
@@ -250,7 +252,7 @@ struct LinearizeVectorShuffle final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
+ Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
assert(!(shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
dstType.cast<VectorType>().isScalable()) &&
@@ -259,8 +261,8 @@ struct LinearizeVectorShuffle final
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
- auto vec1 = adaptor.getV1();
- auto vec2 = adaptor.getV2();
+ Value vec1 = adaptor.getV1();
+ Value vec2 = adaptor.getV2();
int shuffleSliceLen = 1;
int rank = shuffleOp.getV1().getType().getRank();
@@ -269,8 +271,8 @@ struct LinearizeVectorShuffle final
// dims. Mask of the shuffle op specifies which slice to take from the
// outermost dim.
if (rank > 1) {
- auto shape = shuffleOp.getV1().getType().getShape();
- for (unsigned i = 1; i < shape.size(); i++) {
+ llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
+ for (unsigned i = 1; i < shape.size(); ++i) {
shuffleSliceLen *= shape[i];
}
}
@@ -279,8 +281,8 @@ struct LinearizeVectorShuffle final
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
- auto mask = shuffleOp.getMask();
- auto totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
+ ArrayAttr mask = shuffleOp.getMask();
+ int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
for (auto [i, value] :
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
@@ -321,7 +323,7 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstTy = getTypeConverter()->convertType(extractOp.getType());
+ Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(!(extractOp.getVector().getType().isScalable() ||
dstTy.cast<VectorType>().isScalable()) &&
"scalable vectors are not supported.");
@@ -334,12 +336,12 @@ struct LinearizeVectorExtract final
return rewriter.notifyMatchFailure(extractOp,
"dynamic position is not supported.");
- auto shape = extractOp.getVector().getType().getShape();
- auto size = extractOp.getVector().getType().getNumElements();
+ llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
+ int64_t size = extractOp.getVector().getType().getNumElements();
// Compute linearized offset.
int64_t linearizedOffset = 0;
- auto offsets = extractOp.getStaticPosition();
+ llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
for (auto [i, off] : llvm::enumerate(offsets)) {
size /= shape[i];
linearizedOffset += offsets[i] * size;
@@ -397,7 +399,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter, patterns.getContext(), targetBitWidth);
}
-void mlir::vector::populateVectorLinearizeToShuffleRewritePatterns(
+void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e29ea9ce10c68d..c978699e179fca 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -867,7 +867,7 @@ struct TestVectorLinearize final
vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter, patterns, target, targetVectorBitwidth);
- vector::populateVectorLinearizeToShuffleRewritePatterns(
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(
typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
More information about the Mlir-commits
mailing list