[Mlir-commits] [mlir] [mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (PR #88204)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Apr 12 02:26:23 PDT 2024


================
@@ -103,6 +105,234 @@ struct LinearizeVectorizable final
     return success();
   }
 
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// Linearizes a `vector.extract_strided_slice` into a `vector.shuffle` of the
+/// (type-converted) source vector with itself.
+///
+/// If `k` offsets are given for an `n`-D source vector (n >= k), every
+/// extracted "slice" is a contiguous run of elements covering the trailing
+/// (n - k) dimensions. The shuffle mask therefore lists, for every selected
+/// position in the leading `k` dimensions, that contiguous run of source
+/// element indices.
+///
+/// The pattern bails out on scalable vectors, on ops for which
+/// `isLessThanTargetBitWidth(op, targetVectorBitWidth)` does not hold, and on
+/// any stride other than 1.
+struct LinearizeVectorExtractStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorExtractStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    auto loc = extractOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+
+    auto offsets = extractOp.getOffsets().getValue();
+    auto sizes = extractOp.getSizes().getValue();
+    auto strides = extractOp.getStrides().getValue();
+
+    // The shuffle mask built below only ever selects contiguous runs, so every
+    // stride must be 1. Check all of them (not just the first): this stays
+    // correct if non-unit inner strides are ever accepted, and it never
+    // indexes into an empty `strides` array.
+    if (llvm::any_of(strides, [](Attribute stride) {
+          return !isConstantIntValue(stride, 1);
+        }))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
+
+    Value srcVector = adaptor.getVector();
+    auto srcShape = extractOp.getSourceVectorType().getShape();
+    const int64_t n = extractOp.getSourceVectorType().getRank();
+    const auto k = static_cast<int64_t>(offsets.size());
+
+    // If k offsets are specified for an n-D source vector (n > k), the
+    // granularity of the extraction is greater than 1: the trailing (n - k)
+    // dimensions form one contiguous slice. For example:
+    //   %0 = vector.extract_strided_slice %src
+    //          {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+    //          : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // extracts slices of 8 contiguous elements.
+    int64_t extractSliceLen = 1;
+    for (int64_t d = k; d < n; ++d)
+      extractSliceLen *= srcShape[d];
+
+    // Total number of extracted slices: the product of the requested sizes.
+    int64_t nExtractedSlices = 1;
+    for (Attribute size : sizes)
+      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    // Row-major strides (in elements) of the leading k dimensions of the
+    // source vector. The innermost of them advances by one whole slice.
+    llvm::SmallVector<int64_t> sourceStrides(k, extractSliceLen);
+    for (int64_t d = k - 2; d >= 0; --d)
+      sourceStrides[d] = sourceStrides[d + 1] * srcShape[d + 1];
+
+    // Row-major strides (in slices) of the extracted k-D index space.
+    llvm::SmallVector<int64_t> extractedStrides(k, 1);
+    for (int64_t d = k - 2; d >= 0; --d)
+      extractedStrides[d] =
+          extractedStrides[d + 1] * cast<IntegerAttr>(sizes[d + 1]).getInt();
+
+    // The final shuffle mask has nExtractedSlices * extractSliceLen entries.
+    llvm::SmallVector<int64_t> indices(nExtractedSlices * extractSliceLen);
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      // Decode slice number `i` into its k-D position within the extracted
+      // shape, shift each coordinate by the corresponding offset, and
+      // linearize it against the source strides.
+      int64_t remainder = i;
+      int64_t linearizedIndex = 0;
+      for (int64_t d = 0; d < k; ++d) {
+        int64_t pos = remainder / extractedStrides[d];
+        remainder -= pos * extractedStrides[d];
+        linearizedIndex +=
+            (cast<IntegerAttr>(offsets[d]).getInt() + pos) * sourceStrides[d];
+      }
+      // Emit the contiguous run [linearizedIndex, linearizedIndex + sliceLen).
+      for (int64_t j = 0; j < extractSliceLen; ++j)
+        indices[i * extractSliceLen + j] = linearizedIndex + j;
+    }
+
+    // Shuffling the source with itself selects exactly the listed elements.
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+struct LinearizeVectorShffle final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorShffle(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    auto loc = shuffleOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+
+    if (shuffleOp.getV1VectorType().isScalable() ||
+        shuffleOp.getV2VectorType().isScalable() ||
+        dstType.cast<VectorType>().isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
----------------
banach-space wrote:

Could you add a test to capture this? You should be able to re-use `linearize.mlir` — just add `-verify-diagnostics` to the RUN line. Same for the other pattern.

https://github.com/llvm/llvm-project/pull/88204


More information about the Mlir-commits mailing list