[Mlir-commits] [mlir] eda2ebd - [mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 22 03:03:51 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-22T10:03:33Z
New Revision: eda2ebd7807376829eb880c39623f364b438971f

URL: https://github.com/llvm/llvm-project/commit/eda2ebd7807376829eb880c39623f364b438971f
DIFF: https://github.com/llvm/llvm-project/commit/eda2ebd7807376829eb880c39623f364b438971f.diff

LOG: [mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.

Differential Revision: https://reviews.llvm.org/D112301

Added: 
    mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
new file mode 100644
index 0000000000000..13b310713f7b5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -0,0 +1,58 @@
+//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
+#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace vector {
+
+/// Populate `patterns` with the following patterns.
+///
+/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
+/// =======================================================
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have different ranks.
+///
+/// When ranks are different, InsertStridedSlice needs to extract a properly
+/// ranked vector from the destination vector into which to insert. This pattern
+/// only takes care of this extraction part and forwards the rest to
+/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+///
+/// For a k-D source and n-D destination vector (k < n), we emit:
+///   1. ExtractOp, at the leading (n-k) offsets, of the (unique) k-D
+///      destination subvector into which to insert the k-D source.
+///   2. k-D -> k-D InsertStridedSlice op, using the trailing k offsets.
+///   3. InsertOp that is the reverse of 1.
+///
+/// [VectorInsertStridedSliceOpSameRankRewritePattern]
+/// ==================================================
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have the same rank n. When both types are identical, the op is replaced by
+/// its source. Otherwise, for each outermost index in the slice:
+///   begin    end             stride
+/// [offset : offset+size*stride : stride]
+///   1. ExtractOp one (n-1)-D source subvector and one (n-1)-D dest subvector
+///      (for n == 1, a single ExtractElementOp of the source element instead).
+///   2. InsertStridedSlice (n-1)-D into (n-1)-D (skipped for n == 1).
+///   3. InsertOp (InsertElementOp for n == 1) that is the reverse of 1.
+///
+/// [VectorExtractStridedSliceOpRewritePattern]
+/// ===========================================
+/// Progressive lowering of ExtractStridedSliceOp to either:
+///   1. single offset extract as a direct vector::ShuffleOp.
+///   2. ExtractOp + lower rank ExtractStridedSliceOp + InsertOp for the n-D
+///      case (more than one offset).
+void populateVectorInsertExtractStridedSliceTransforms(
+    RewritePatternSet &patterns);
+
+} // namespace vector
+} // namespace mlir
+
+#endif // DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 59e6ac07bbca3..d26636c132ac4 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -24,13 +24,6 @@ namespace scf {
 class IfOp;
 } // namespace scf
 
-/// Collect a set of patterns to convert from the Vector dialect to itself.
-/// Should be merged with populateVectorToSCFLoweringPattern.
-void populateVectorToVectorConversionPatterns(
-    MLIRContext *context, RewritePatternSet &patterns,
-    ArrayRef<int64_t> coarseVectorShape = {},
-    ArrayRef<int64_t> fineVectorShape = {});
-
 namespace vector {
 
 /// Options that control the vector unrolling.

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 2df0229ec2f4a..788e0c8316e23 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
 #include "llvm/ADT/DenseMap.h"
@@ -184,6 +185,11 @@ bool checkSameValueRAW(vector::TransferWriteOp defWrite,
 bool checkSameValueWAW(vector::TransferWriteOp write,
                        vector::TransferWriteOp priorWrite);
 
+/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+///
+/// The first `dropFront` and the last `dropBack` elements of `arrayAttr` are
+/// skipped. Every remaining element is cast to IntegerAttr and its value is
+/// returned sign-extended to int64_t. Asserts that at least one element
+/// remains, i.e. `arrayAttr.size() > dropFront + dropBack`.
+SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                       unsigned dropFront = 0,
+                                       unsigned dropBack = 0);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a6f25332d1331..77d2a46977172 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
                                               rewriter.getI64ArrayAttr(pos));
 }
 
-// Helper that picks the proper sequence for inserting.
-static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
-                       Value into, int64_t offset) {
-  auto vectorType = into.getType().cast<VectorType>();
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &typeConverter, Location loc,
@@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
                                                rewriter.getI64ArrayAttr(pos));
 }
 
-// Helper that picks the proper sequence for extracting.
-static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
-                        int64_t offset) {
-  auto vectorType = vector.getType().cast<VectorType>();
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
-// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-// TODO: Better support for attribute subtype forwarding + slicing.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
-                                              unsigned dropFront = 0,
-                                              unsigned dropBack = 0) {
-  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
-  auto range = arrayAttr.getAsRange<IntegerAttr>();
-  SmallVector<int64_t, 4> res;
-  res.reserve(arrayAttr.size() - dropFront - dropBack);
-  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
-       it != eit; ++it)
-    res.push_back((*it).getValue().getSExtValue());
-  return res;
-}
-
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
@@ -813,132 +777,6 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
   }
 };
 
-// When ranks are different, InsertStridedSlice needs to extract a properly
-// ranked vector from the destination vector into which to insert. This pattern
-// only takes care of this part and forwards the rest of the conversion to
-// another pattern that converts InsertStridedSlice for operands of the same
-// rank.
-//
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have different ranks. In this case:
-//   1. the proper subvector is extracted from the destination vector
-//   2. a new InsertStridedSlice op is created to insert the source in the
-//   destination subvector
-//   3. the destination subvector is inserted back in the proper place
-//   4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpDifferentRankRewritePattern
-    : public OpRewritePattern<InsertStridedSliceOp> {
-public:
-  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    auto srcType = op.getSourceVectorType();
-    auto dstType = op.getDestVectorType();
-
-    if (op.offsets().getValue().empty())
-      return failure();
-
-    auto loc = op.getLoc();
-    int64_t rankDiff = dstType.getRank() - srcType.getRank();
-    assert(rankDiff >= 0);
-    if (rankDiff == 0)
-      return failure();
-
-    int64_t rankRest = dstType.getRank() - rankDiff;
-    // Extract / insert the subvector of matching rank and InsertStridedSlice
-    // on it.
-    Value extracted =
-        rewriter.create<ExtractOp>(loc, op.dest(),
-                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
-                                                  /*dropBack=*/rankRest));
-    // A different pattern will kick in for InsertStridedSlice with matching
-    // ranks.
-    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
-        loc, op.source(), extracted,
-        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
-        getI64SubArray(op.strides(), /*dropFront=*/0));
-    rewriter.replaceOpWithNewOp<InsertOp>(
-        op, stridedSliceInnerOp.getResult(), op.dest(),
-        getI64SubArray(op.offsets(), /*dropFront=*/0,
-                       /*dropBack=*/rankRest));
-    return success();
-  }
-};
-
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have the same rank. In this case, we reduce
-//   1. the proper subvector is extracted from the destination vector
-//   2. a new InsertStridedSlice op is created to insert the source in the
-//   destination subvector
-//   3. the destination subvector is inserted back in the proper place
-//   4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpSameRankRewritePattern
-    : public OpRewritePattern<InsertStridedSliceOp> {
-public:
-  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
-  void initialize() {
-    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
-    // bounded as the rank is strictly decreasing.
-    setHasBoundedRewriteRecursion();
-  }
-
-  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    auto srcType = op.getSourceVectorType();
-    auto dstType = op.getDestVectorType();
-
-    if (op.offsets().getValue().empty())
-      return failure();
-
-    int64_t rankDiff = dstType.getRank() - srcType.getRank();
-    assert(rankDiff >= 0);
-    if (rankDiff != 0)
-      return failure();
-
-    if (srcType == dstType) {
-      rewriter.replaceOp(op, op.source());
-      return success();
-    }
-
-    int64_t offset =
-        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
-    int64_t size = srcType.getShape().front();
-    int64_t stride =
-        op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
-    auto loc = op.getLoc();
-    Value res = op.dest();
-    // For each slice of the source vector along the most major dimension.
-    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
-         off += stride, ++idx) {
-      // 1. extract the proper subvector (or element) from source
-      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
-      if (extractedSource.getType().isa<VectorType>()) {
-        // 2. If we have a vector, extract the proper subvector from destination
-        // Otherwise we are at the element level and no need to recurse.
-        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
-        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
-        // smaller rank.
-        extractedSource = rewriter.create<InsertStridedSliceOp>(
-            loc, extractedSource, extractedDest,
-            getI64SubArray(op.offsets(), /* dropFront=*/1),
-            getI64SubArray(op.strides(), /* dropFront=*/1));
-      }
-      // 4. Insert the extractedSource into the res vector.
-      res = insertOne(rewriter, loc, extractedSource, res, off);
-    }
-
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-};
-
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1189,67 +1027,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   }
 };
 
-/// Progressive lowering of ExtractStridedSliceOp to either:
-///   1. express single offset extract as a direct shuffle.
-///   2. extract + lower rank strided_slice + insert for the n-D case.
-class VectorExtractStridedSliceOpConversion
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
-
-  void initialize() {
-    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
-    // is bounded as the rank is strictly decreasing.
-    setHasBoundedRewriteRecursion();
-  }
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    auto dstType = op.getType();
-
-    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
-
-    int64_t offset =
-        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
-    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
-    int64_t stride =
-        op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
-    auto loc = op.getLoc();
-    auto elemType = dstType.getElementType();
-    assert(elemType.isSignlessIntOrIndexOrFloat());
-
-    // Single offset can be more efficiently shuffled.
-    if (op.offsets().getValue().size() == 1) {
-      SmallVector<int64_t, 4> offsets;
-      offsets.reserve(size);
-      for (int64_t off = offset, e = offset + size * stride; off < e;
-           off += stride)
-        offsets.push_back(off);
-      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
-                                             op.vector(),
-                                             rewriter.getI64ArrayAttr(offsets));
-      return success();
-    }
-
-    // Extract/insert on a lower ranked extract strided slice op.
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, elemType, rewriter.getZeroAttr(elemType));
-    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
-    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
-         off += stride, ++idx) {
-      Value one = extractOne(rewriter, loc, op.vector(), off);
-      Value extracted = rewriter.create<ExtractStridedSliceOp>(
-          loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
-          getI64SubArray(op.sizes(), /* dropFront=*/1),
-          getI64SubArray(op.strides(), /* dropFront=*/1));
-      res = insertOne(rewriter, loc, extracted, res, idx);
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-};
-
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern,
-               VectorInsertStridedSliceOpDifferentRankRewritePattern,
-               VectorInsertStridedSliceOpSameRankRewritePattern,
-               VectorExtractStridedSliceOpConversion>(ctx);
+  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
+  populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 594e8cd3bb7d3..f620a370c8359 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRVector
-  VectorOps.cpp
+  VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorMultiDimReductionTransforms.cpp
+  VectorOps.cpp
   VectorTransferOpTransforms.cpp
   VectorTransforms.cpp
   VectorUtils.cpp

diff  --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
new file mode 100644
index 0000000000000..1dc04027266a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -0,0 +1,236 @@
+//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+// Inserts `from` into `into` at position `offset` of the leading dimension and
+// returns the updated vector. A vector::InsertOp is used when `into` has rank
+// greater than 1; otherwise the position is materialized as an index constant
+// and a vector::InsertElementOp is emitted.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+                       Value into, int64_t offset) {
+  auto intoType = into.getType().cast<VectorType>();
+  if (intoType.getRank() <= 1) {
+    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
+    return rewriter.create<vector::InsertElementOp>(loc, intoType, from, into,
+                                                    position);
+  }
+  return rewriter.create<InsertOp>(loc, from, into, offset);
+}
+
+// Extracts the entry at position `offset` of the leading dimension of
+// `vector`. A vector::ExtractOp (yielding a lower-rank vector) is used when
+// `vector` has rank greater than 1; otherwise the position is materialized as
+// an index constant and a vector::ExtractElementOp yields a scalar of the
+// element type.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+                        int64_t offset) {
+  auto sourceType = vector.getType().cast<VectorType>();
+  if (sourceType.getRank() <= 1) {
+    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
+    return rewriter.create<vector::ExtractElementOp>(
+        loc, sourceType.getElementType(), vector, position);
+  }
+  return rewriter.create<ExtractOp>(loc, vector, offset);
+}
+
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have different ranks.
+///
+/// When ranks are different, InsertStridedSlice needs to extract a properly
+/// ranked vector from the destination vector into which to insert. This pattern
+/// only takes care of this extraction part and forwards the rest to
+/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+///
+/// For a k-D source and n-D destination vector (k < n), we emit:
+///   1. ExtractOp, at the leading (n-k) offsets, of the (unique) k-D
+///      destination subvector into which to insert the k-D source.
+///   2. k-D -> k-D InsertStridedSlice op, using the trailing k offsets.
+///   3. InsertOp that is the reverse of 1.
+/// Ops with empty offsets, or whose source and destination ranks match, are
+/// left to other patterns.
+class VectorInsertStridedSliceOpDifferentRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    auto loc = op.getLoc();
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff == 0)
+      return failure();
+
+    // Equals the source rank: the number of trailing offsets that are
+    // forwarded to the same-rank InsertStridedSlice.
+    int64_t rankRest = dstType.getRank() - rankDiff;
+    // Extract / insert the subvector of matching rank and InsertStridedSlice
+    // on it.
+    Value extracted =
+        rewriter.create<ExtractOp>(loc, op.dest(),
+                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
+                                                  /*dropBack=*/rankRest));
+
+    // A different pattern will kick in for InsertStridedSlice with matching
+    // ranks.
+    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
+        loc, op.source(), extracted,
+        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
+        getI64SubArray(op.strides(), /*dropFront=*/0));
+
+    rewriter.replaceOpWithNewOp<InsertOp>(
+        op, stridedSliceInnerOp.getResult(), op.dest(),
+        getI64SubArray(op.offsets(), /*dropFront=*/0,
+                       /*dropBack=*/rankRest));
+    return success();
+  }
+};
+
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have the same rank n. When both types are identical, the op is replaced by
+/// its source. Otherwise, for each outermost index in the slice:
+///   begin    end             stride
+/// [offset : offset+size*stride : stride]
+///   1. ExtractOp one (n-1)-D source subvector and one (n-1)-D dest subvector
+///      (for n == 1, a single ExtractElementOp of the source element instead).
+///   2. InsertStridedSlice (n-1)-D into (n-1)-D (skipped for n == 1).
+///   3. InsertOp (InsertElementOp for n == 1) that is the reverse of 1.
+/// Ops with empty offsets or with differing ranks are left to other patterns.
+class VectorInsertStridedSliceOpSameRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
+    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
+    // bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff != 0)
+      return failure();
+
+    // Same rank and same shape: the source overwrites the whole destination.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op.source());
+      return success();
+    }
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = srcType.getShape().front();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    Value res = op.dest();
+    // For each slice of the source vector along the most major dimension.
+    // `off` walks the destination positions, `idx` the source positions.
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      // 1. extract the proper subvector (or element) from source
+      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
+      if (extractedSource.getType().isa<VectorType>()) {
+        // 2. If we have a vector, extract the proper subvector from destination
+        // Otherwise we are at the element level and no need to recurse.
+        // Reading op.dest() rather than `res` presumably relies on each
+        // iteration touching a distinct position `off` (assumes stride > 0).
+        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
+        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
+        // smaller rank.
+        extractedSource = rewriter.create<InsertStridedSliceOp>(
+            loc, extractedSource, extractedDest,
+            getI64SubArray(op.offsets(), /* dropFront=*/1),
+            getI64SubArray(op.strides(), /* dropFront=*/1));
+      }
+      // 4. Insert the extractedSource into the res vector.
+      res = insertOne(rewriter, loc, extractedSource, res, off);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
+/// Progressive lowering of ExtractStridedSliceOp to either:
+///   1. single offset extract as a direct vector::ShuffleOp.
+///   2. ExtractOp + lower rank ExtractStridedSliceOp + InsertOp for the n-D
+///      case (more than one offset), accumulated into a zero splat of the
+///      result type.
+/// Ops with empty offsets fail to match, as do n-D cases whose element type
+/// cannot form a zero constant (not signless int, index or float); previously
+/// both only asserted, leaving release builds to read `front()` of an empty
+/// array or build an invalid constant.
+class VectorExtractStridedSliceOpRewritePattern
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
+    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+
+    // Nothing to peel off: bail out like the InsertStridedSliceOp patterns
+    // instead of reading `front()` of an empty array below.
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    auto elemType = dstType.getElementType();
+
+    // Single offset can be more efficiently shuffled.
+    if (op.offsets().getValue().size() == 1) {
+      SmallVector<int64_t, 4> offsets;
+      offsets.reserve(size);
+      for (int64_t off = offset, e = offset + size * stride; off < e;
+           off += stride)
+        offsets.push_back(off);
+      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
+                                             op.vector(),
+                                             rewriter.getI64ArrayAttr(offsets));
+      return success();
+    }
+
+    // The zero splat below needs a zero constant of the element type; check
+    // before creating any IR so a failed match leaves the IR untouched.
+    if (!elemType.isSignlessIntOrIndexOrFloat())
+      return failure();
+
+    // Extract/insert on a lower ranked extract strided slice op.
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, elemType, rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      Value one = extractOne(rewriter, loc, op.vector(), off);
+      Value extracted = rewriter.create<ExtractStridedSliceOp>(
+          loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
+          getI64SubArray(op.sizes(), /* dropFront=*/1),
+          getI64SubArray(op.strides(), /* dropFront=*/1));
+      res = insertOne(rewriter, loc, extracted, res, idx);
+    }
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
+/// Populate `patterns` with the vector-to-vector rewrites defined in this file:
+/// VectorInsertStridedSliceOpDifferentRankRewritePattern,
+/// VectorInsertStridedSliceOpSameRankRewritePattern and
+/// VectorExtractStridedSliceOpRewritePattern. Together they progressively
+/// rewrite vector.insert_strided_slice / vector.extract_strided_slice into
+/// vector extract/insert(-element) and shuffle ops. The patterns are created
+/// with the context of `patterns`.
+void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+               VectorInsertStridedSliceOpSameRankRewritePattern,
+               VectorExtractStridedSliceOpRewritePattern>(
+      patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 769a416278150..b0ba6a8f94fb5 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2204,20 +2204,6 @@ class StridedSliceConstantFolder final
   }
 };
 
-// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
-                                              unsigned dropFront = 0,
-                                              unsigned dropBack = 0) {
-  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
-  auto range = arrayAttr.getAsRange<IntegerAttr>();
-  SmallVector<int64_t, 4> res;
-  res.reserve(arrayAttr.size() - dropFront - dropBack);
-  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
-       it != eit; ++it)
-    res.push_back((*it).getValue().getSExtValue());
-  return res;
-}
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 75ee23b5ffb3b..d98fa705dbf62 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1034,10 +1034,11 @@ class ShapeCastOp2DDownCastRewritePattern
 };
 
 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively on the way from targeting llvm.matrix intrinsics.
+/// vectors progressively.
 /// This iterates over the most major dimension of the 2-D vector and performs
 /// rewrites into:
-///   vector.strided_slice from 1-D + vector.insert into 2-D
+///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
 class ShapeCastOp2DUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 937e303968a36..2659a313c464b 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -362,3 +362,16 @@ bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
          priorWrite.getVectorType() == write.getVectorType() &&
          priorWrite.permutation_map() == write.permutation_map();
 }
+
+// Returns the sign-extended values of the IntegerAttr elements of `arrayAttr`,
+// skipping the first `dropFront` and the last `dropBack` elements. At least one
+// element must remain.
+SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
+                                             unsigned dropFront,
+                                             unsigned dropBack) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  // Keep only the window [dropFront, size - dropBack) of the attribute list.
+  ArrayRef<Attribute> window =
+      arrayAttr.getValue().drop_front(dropFront).drop_back(dropBack);
+  SmallVector<int64_t, 4> values;
+  values.reserve(window.size());
+  for (Attribute element : window)
+    values.push_back(element.cast<IntegerAttr>().getValue().getSExtValue());
+  return values;
+}


        


More information about the Mlir-commits mailing list