[Mlir-commits] [mlir] [mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize (PR #92370)

Artem Kroviakov llvmlistbot at llvm.org
Thu May 16 09:37:53 PDT 2024


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/92370

>From d991dc9c9ba4802505a5dc0a3125b9c62b1c1e71 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 16 May 2024 09:37:23 -0700
Subject: [PATCH] [mlir][vector] Add support for linearizing Insert VectorOp in
 VectorLinearize

Building on top of #88204, this commit adds support for InsertOp.
---
 .../Vector/Transforms/VectorLinearize.cpp     | 98 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 29 ++++++
 2 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 802a64b0805ee..492cbe284cdad 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -44,6 +44,20 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   return true;
 }
 
+static bool isLessThanOrEqualTargetBitWidth(Type t,
+                                            unsigned targetBitWidth) {
+  VectorType vecType = dyn_cast<VectorType>(t);
+  // Reject index since getElementTypeBitWidth will abort for Index types.
+  if (!vecType || vecType.getElementType().isIndex())
+    return false;
+  // There is no dimension to fold if it is a 0-D vector.
+  if (vecType.getRank() == 0)
+    return false;
+  unsigned trailingVecDimBitWidth =
+      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+  return trailingVecDimBitWidth <= targetBitWidth;
+}
+
 namespace {
 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -355,6 +369,88 @@ struct LinearizeVectorExtract final
     return success();
   }
 
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the InsertOp to a ShuffleOp that works on a
+/// linearized vector.
+/// Following,
+///   vector.insert %source, %destination [ position ]
+/// is converted to:
+///   %source_1d = vector.shape_cast %source
+///   %destination_1d = vector.shape_cast %destination
+///   %out_1d = vector.shuffle %destination_1d, %source_1d [shuffle_indices_1d]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original insert.
+struct LinearizeVectorInsert final
+    : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsert(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
+    assert(!(insertOp.getDestVectorType().isScalable() ||
+             cast<VectorType>(dstTy).isScalable()) &&
+           "scalable vectors are not supported.");
+
+    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+                                         targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          insertOp, "Can't flatten since targetBitWidth < OpSize");
+
+    // Only static positions can be folded into a constant shuffle mask.
+    if (insertOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "dynamic position is not supported.");
+    auto srcTy = insertOp.getSourceType();
+    auto srcAsVec = dyn_cast<VectorType>(srcTy);
+    uint64_t srcSize = 0;
+    if (srcAsVec) {
+      srcSize = srcAsVec.getNumElements();
+    } else {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalars are not supported.");
+    }
+
+    auto dstShape = insertOp.getDestVectorType().getShape();
+    const auto dstSize = insertOp.getDestVectorType().getNumElements();
+    auto dstSizeForOffsets = dstSize;
+
+    // Compute the row-major linearized offset of the insert position.
+    int64_t linearizedOffset = 0;
+    auto offsetsNd = insertOp.getStaticPosition();
+    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
+      dstSizeForOffsets /= dstShape[dim];
+      linearizedOffset += offset * dstSizeForOffsets;
+    }
+    // Shuffle mask ids: [0, dstSize) select from dst, dstSize and up from src.
+    llvm::SmallVector<int64_t, 2> indices(dstSize);
+    auto origValsUntil = indices.begin();
+    std::advance(origValsUntil, linearizedOffset);
+    std::iota(indices.begin(), origValsUntil,
+              0); // dst values kept in [0, offset)
+    auto newValsUntil = origValsUntil;
+    std::advance(newValsUntil, srcSize);
+    std::iota(origValsUntil, newValsUntil,
+              dstSize); // src values placed in [offset, offset+srcSize)
+    std::iota(newValsUntil, indices.end(),
+              linearizedOffset + srcSize); // remaining dst values in
+                                           // [offset+srcSize, dstSize)
+
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
+        rewriter.getI64ArrayAttr(indices));
+
+    return success();
+  }
+
 private:
   unsigned targetVectorBitWidth;
 };
@@ -410,6 +506,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                    : true;
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index b29ceab5783d7..31a59b809a74b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
   return %0 : vector<8x2xf32>
 }
+
+// -----
+// ALL-LABEL: test_vector_insert
+// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+  // DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+  // DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+  // DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+  // DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+  // DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+  // DEFAULT: return %[[RES]] : vector<2x8x4xf32>
+
+  // BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+  // BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+  // BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+  // BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+  // BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+  // BW-128: return %[[RES]] : vector<2x8x4xf32>
+
+  // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
+  // BW-0: return %[[RES]] : vector<2x8x4xf32>
+
+  %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+  return %0 : vector<2x8x4xf32>
+}



More information about the Mlir-commits mailing list