[Mlir-commits] [mlir] [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (PR #88204)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Apr 12 02:26:23 PDT 2024
================
@@ -103,6 +105,234 @@ struct LinearizeVectorizable final
return success();
}
+private:
+ unsigned targetVectorBitWidth;
+};
+
+/// Linearizes a `vector.extract_strided_slice` into a single `vector.shuffle`
+/// of the already-flattened (adaptor) source vector with itself.
+///
+/// Bails out (match failure) when the result type cannot be converted, when
+/// the source or converted result is scalable, when
+/// `isLessThanTargetBitWidth(op, targetVectorBitWidth)` rejects the op, or
+/// when any stride is not the constant 1.
+struct LinearizeVectorExtractStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtractStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
+    Location loc = extractOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    auto offsets = extractOp.getOffsets().getValue();
+    auto sizes = extractOp.getSizes().getValue();
+    auto strides = extractOp.getStrides().getValue();
+
+    // The shuffle built below copies contiguous runs of source elements, which
+    // is only correct if *every* stride (not just the outermost one) is 1.
+    for (Attribute stride : strides) {
+      if (!isConstantIntValue(stride, 1))
+        return rewriter.notifyMatchFailure(
+            extractOp, "Strided slice with stride != 1 is not supported.");
+    }
+
+    Value srcVector = adaptor.getVector();
+    VectorType srcType = extractOp.getSourceVectorType();
+    auto srcShape = srcType.getShape();
+
+    // If k offsets are given for an n-D source (n > k), every extracted
+    // "slice" is a contiguous run spanning the trailing (n - k) dimensions.
+    // Example:
+    //   %0 = vector.extract_strided_slice %src
+    //          {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+    //          : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // extracts slices of length 8.
+    const int64_t n = srcType.getRank();
+    const int64_t k = static_cast<int64_t>(offsets.size());
+    int64_t extractSliceLen = 1;
+    for (int64_t d = k; d < n; ++d)
+      extractSliceLen *= srcShape[d];
+
+    // Number of extracted slices = product of the requested sizes.
+    int64_t nExtractedSlices = 1;
+    for (Attribute size : sizes)
+      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    // Row-major strides over the first k dims: in elements of the flattened
+    // source (sourceStrides) and in slices of the extracted k-D grid
+    // (extractedStrides).
+    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+    llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
+    for (int64_t d = k - 2; d >= 0; --d) {
+      sourceStrides[d] = sourceStrides[d + 1] * srcShape[d + 1];
+      extractedStrides[d] =
+          extractedStrides[d + 1] * cast<IntegerAttr>(sizes[d + 1]).getInt();
+    }
+
+    // Unpack the offsets once instead of once per slice.
+    llvm::SmallVector<int64_t, 4> offsetValues;
+    offsetValues.reserve(k);
+    for (Attribute offset : offsets)
+      offsetValues.push_back(cast<IntegerAttr>(offset).getInt());
+
+    // For each slice: decompose its linear id into a k-D coordinate, shift it
+    // by the offsets, re-linearize it into the source, and emit the
+    // extractSliceLen consecutive source positions that make up the slice.
+    llvm::SmallVector<int64_t, 4> indices;
+    indices.reserve(nExtractedSlices * extractSliceLen);
+    for (int64_t slice = 0; slice < nExtractedSlices; ++slice) {
+      int64_t remaining = slice;
+      int64_t linearizedIndex = 0;
+      for (int64_t d = 0; d < k; ++d) {
+        const int64_t coord = remaining / extractedStrides[d];
+        remaining -= coord * extractedStrides[d];
+        linearizedIndex += (offsetValues[d] + coord) * sourceStrides[d];
+      }
+      for (int64_t e = 0; e < extractSliceLen; ++e)
+        indices.push_back(linearizedIndex + e);
+    }
+
+    // Shuffling the flattened source with itself selects exactly `indices`.
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+struct LinearizeVectorShffle final
+ : public OpConversionPattern<vector::ShuffleOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorShffle(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
+ auto loc = shuffleOp.getLoc();
+ if (!dstType)
+ return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+
+ if (shuffleOp.getV1VectorType().isScalable() ||
+ shuffleOp.getV2VectorType().isScalable() ||
+ dstType.cast<VectorType>().isScalable())
+ return rewriter.notifyMatchFailure(loc,
+ "scalable vectors are not supported.");
+ if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+
+ auto vec1 = adaptor.getV1();
+ auto vec2 = adaptor.getV2();
+
+ int shuffleSliceLen = 1;
+ int rank = shuffleOp.getV1().getType().getRank();
+
+ // if rank > 1, we need to do the shuffle in the granularity of slices
+ // instead of scalars. Size of the slice is equal to the rank-1 innermost
+ // dims. Mask of the shuffle op specifies which slice to take from the
+ // outermost dim.
+ if (rank > 1) {
+ auto shape = shuffleOp.getV1().getType().getShape();
+ for (unsigned i = 1; i < shape.size(); i++) {
+ shuffleSliceLen *= shape[i];
+ }
+ }
+
+ auto mask = shuffleOp.getMask();
+ auto totalSize = mask.size() * shuffleSliceLen;
----------------
banach-space wrote:
`totalSize` of what?
https://github.com/llvm/llvm-project/pull/88204
More information about the Mlir-commits mailing list