[Mlir-commits] [mlir] [mlir][Vector] Move insert/extractelement distribution patterns to insert/extract (PR #116425)

Kunwar Grover llvmlistbot at llvm.org
Mon Nov 18 02:20:10 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/116425

>From 6f867b11bb5b93474d46d849a91af7c75f288bd1 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 15 Nov 2024 19:18:51 +0000
Subject: [PATCH 1/2] [mlir][Vector] Move insert/extractelement distribution
 patterns to insert/extract

---
 .../Vector/Transforms/VectorDistribute.cpp    | 241 ++++++++++--------
 .../Vector/vector-warp-distribute.mlir        |  26 +-
 2 files changed, 149 insertions(+), 118 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 682eb82ac58408..bad26b0cc4b383 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     VectorType extractSrcType = extractOp.getSourceVectorType();
     Location loc = extractOp.getLoc();
 
-    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
-    assert(extractSrcType.getRank() > 0 &&
-           "vector.extract does not support rank 0 sources");
-
-    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
-    // canonicalized to %v.
-    if (extractOp.getNumIndices() == 0)
+    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
+    if (extractSrcType.getRank() <= 1) {
       return failure();
-
-    // Rewrite vector.extract with 1d source to vector.extractelement.
-    if (extractSrcType.getRank() == 1) {
-      if (extractOp.hasDynamicPosition())
-        // TODO: Dinamic position not supported yet.
-        return failure();
-
-      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
-      int64_t pos = extractOp.getStaticPosition()[0];
-      rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
-      return success();
     }
 
     // All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
-/// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
-                       PatternBenefit b = 1)
+/// Pattern to move out vector.extract with a scalar result.
+/// Only supports 1-D and 0-D sources for now.
+struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+                      PatternBenefit b = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
         warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
     VectorType extractSrcType = extractOp.getSourceVectorType();
+    // Only supports 1-D or 0-D sources for now.
+    if (extractSrcType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "only 0-D or 1-D source supported for now");
+    }
     // TODO: Supported shuffle types should be parameterizable, similar to
     // `WarpShuffleFromIdxFn`.
     if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     VectorType distributedVecType;
     if (!is0dOrVec1Extract) {
       assert(extractSrcType.getRank() == 1 &&
-             "expected that extractelement src rank is 0 or 1");
+             "expected that extract src rank is 0 or 1");
       if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
         return failure();
       int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Yield source vector and position (if present) from warp op.
     SmallVector<Value> additionalResults{extractOp.getVector()};
     SmallVector<Type> additionalResultTypes{distributedVecType};
-    if (static_cast<bool>(extractOp.getPosition())) {
-      additionalResults.push_back(extractOp.getPosition());
-      additionalResultTypes.push_back(extractOp.getPosition().getType());
-    }
+    additionalResults.append(
+        SmallVector<Value>(extractOp.getDynamicPosition()));
+    additionalResultTypes.append(
+        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
+
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (is0dOrVec1Extract) {
       Value newExtract;
       if (extractSrcType.getRank() == 1) {
-        newExtract = rewriter.create<vector::ExtractElementOp>(
-            loc, distributedVec,
-            rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
       } else {
-        newExtract =
-            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
+                                                        ArrayRef<int64_t>{});
       }
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newExtract);
       return success();
     }
 
+    int64_t staticPos = extractOp.getStaticPosition()[0];
+    OpFoldResult pos = ShapedType::isDynamic(staticPos)
+                           ? (newWarpOp->getResult(newRetIndices[1]))
+                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
     // 1d extract: Distribute the source vector. One lane extracts and shuffles
     // the value to all other lanes.
     int64_t elementsPerLane = distributedVecType.getShape()[0];
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
-    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane),
-        newWarpOp->getResult(newRetIndices[1]));
+    Value broadcastFromTid = affine::makeComposedAffineApply(
+        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
     // Extract at position: pos % elementsPerLane
-    Value pos =
+    Value newPos =
         elementsPerLane == 1
             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
-            : rewriter
-                  .create<affine::AffineApplyOp>(
-                      loc, sym0 % elementsPerLane,
-                      newWarpOp->getResult(newRetIndices[1]))
-                  .getResult();
+            : affine::makeComposedAffineApply(rewriter, loc,
+                                              sym0 % elementsPerLane, pos);
     Value extracted =
-        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
 
     // Shuffle the extracted value to all lanes.
     Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
-struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+/// Pattern to convert vector.extractelement to vector.extract.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+    if (!operand)
+      return failure();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    rewriter.setInsertionPoint(extractOp);
+    if (auto pos = extractOp.getPosition()) {
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
+    } else {
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+    }
+    return success();
+  }
+};
+
+/// Pattern to move out vector.insert with a scalar input.
+/// Only supports 1-D and 0-D destinations for now.
+struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
     VectorType vecType = insertOp.getDestVectorType();
     VectorType distrType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
-    bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+    // Only supports 1-D or 0-D destinations for now.
+    if (vecType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          insertOp, "only 0-D or 1-D source supported for now");
+    }
 
     // Yield destination vector, source scalar and position from warp op.
     SmallVector<Value> additionalResults{insertOp.getDest(),
                                          insertOp.getSource()};
     SmallVector<Type> additionalResultTypes{distrType,
                                             insertOp.getSource().getType()};
-    if (hasPos) {
-      additionalResults.push_back(insertOp.getPosition());
-      additionalResultTypes.push_back(insertOp.getPosition().getType());
-    }
+    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
+    additionalResultTypes.append(
+        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
+
     Location loc = insertOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
     Value newSource = newWarpOp->getResult(newRetIndices[1]);
-    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
     rewriter.setInsertionPointAfter(newWarpOp);
 
+    OpFoldResult pos;
+    if (vecType.getRank() != 0) {
+      int64_t staticPos = insertOp.getStaticPosition()[0];
+      pos = ShapedType::isDynamic(staticPos)
+                ? (newWarpOp->getResult(newRetIndices[2]))
+                : OpFoldResult(rewriter.getIndexAttr(staticPos));
+    }
+
+    // This condition is always true for 0-d vectors.
     if (vecType == distrType) {
-      // Broadcast: Simply move the vector.inserelement op out.
-      Value newInsert = rewriter.create<vector::InsertElementOp>(
-          loc, newSource, distributedVec, newPos);
+      Value newInsert;
+      if (pos) {
+        newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+                                                      distributedVec, pos);
+      } else {
+        newInsert = rewriter.create<vector::InsertOp>(
+            loc, newSource, distributedVec, ArrayRef<int64_t>{});
+      }
+      // Broadcast: Simply move the vector.insert op out.
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newInsert);
       return success();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     int64_t elementsPerLane = distrType.getShape()[0];
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
-    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane), newPos);
+    Value insertingLane = affine::makeComposedAffineApply(
+        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
     // Insert position: pos % elementsPerLane
-    Value pos =
-        elementsPerLane == 1
-            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
-            : rewriter
-                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                                 newPos)
-                  .getResult();
+    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, sym0 % elementsPerLane, pos);
     Value isInsertingLane = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
     Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
                 loc, isInsertingLane,
                 /*thenBuilder=*/
                 [&](OpBuilder &builder, Location loc) {
-                  Value newInsert = builder.create<vector::InsertElementOp>(
-                      loc, newSource, distributedVec, pos);
+                  Value newInsert = builder.create<vector::InsertOp>(
+                      loc, newSource, distributedVec, newPos);
                   builder.create<scf::YieldOp>(loc, newInsert);
                 },
                 /*elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
     Location loc = insertOp.getLoc();
 
-    // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
-    if (insertOp.getNumIndices() == 0)
+    // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
+    if (insertOp.getDestVectorType().getRank() <= 1) {
       return failure();
-
-    // Rewrite vector.insert with 1d dest to vector.insertelement.
-    if (insertOp.getDestVectorType().getRank() == 1) {
-      if (insertOp.hasDynamicPosition())
-        // TODO: Dinamic position not supported yet.
-        return failure();
-
-      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
-      int64_t pos = insertOp.getStaticPosition()[0];
-      rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
-      return success();
     }
 
+    // All following cases are 2d or higher dimensional source vectors.
+
     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
       // There is no distribution, this is a broadcast. Simply move the insert
       // out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+    if (!operand)
+      return failure();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    rewriter.setInsertionPoint(insertOp);
+    if (auto pos = insertOp.getPosition()) {
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
+    } else {
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(),
+          ArrayRef<int64_t>{});
+    }
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
-/// the scf.ForOp is the last operation in the region so that it doesn't change
-/// the order of execution. This creates a new scf.for region after the
+/// the scf.ForOp is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.for region after the
 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
 /// WarpExecuteOnLane0Op region. Example:
 /// ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to the
-    // new op.
+    // Those Value needs to be returned by the original warpOp and passed to
+    // the new op.
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
 
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
-    // inside.
+    // Create a new for op outside the region with a WarpExecuteOnLane0Op
+    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
 };
 
 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
-/// The vector is reduced in parallel. Currently limited to vector size matching
-/// the warpOp size. E.g.:
+/// The vector is reduced in parallel. Currently limited to vector size
+/// matching the warpOp size. E.g.:
 /// ```
 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
 ///   %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
-          patterns.getContext(), benefit);
-  patterns.add<WarpOpExtractElement>(patterns.getContext(),
-                                     warpShuffleFromIdxFn, benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+      patterns.getContext(), benefit);
+  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
+                                    benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3acddd6e54639e..b4491812dc26cb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -783,7 +783,7 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<64xf32>
 //       CHECK-PROP:     vector.yield %[[V]] : vector<64xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
 //       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[E]], %[[C5_I32]]
 //       CHECK-PROP:   return %[[SHUFFLED]] : f32
 func.func @vector_extract_1d(%laneid: index) -> (f32) {
@@ -874,7 +874,7 @@ func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
 //       CHECK-PROP:     vector.yield %[[V]] : vector<f32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][] : f32 from vector<f32>
 //       CHECK-PROP:   return %[[E]] : f32
 func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
@@ -888,12 +888,11 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
 // -----
 
 // CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
-//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<1xf32>
 //       CHECK-PROP:     vector.yield %[[V]] : vector<1xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][0] : f32 from vector<1xf32>
 //       CHECK-PROP:   return %[[E]] : f32
 func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
@@ -918,7 +917,7 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
 //       CHECK-PROP:   }
 //       CHECK-PROP:   %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]]
 //       CHECK-PROP:   %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]]
-//       CHECK-PROP:   %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32>
+//       CHECK-PROP:   %[[EXTRACTED:.*]] = vector.extract %[[W]][%[[DISTR_POS]]] : f32 from vector<3xf32>
 //       CHECK-PROP:   %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
 //       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
 //       CHECK-PROP:   return %[[SHUFFLED]]
@@ -938,7 +937,7 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
 // CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
 //       CHECK-PROP:   vector.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
 //       CHECK-PROP:     "some_def"
-//       CHECK-PROP:     vector.extractelement
+//       CHECK-PROP:     vector.extract
 //       CHECK-PROP:     vector.yield {{.*}} : index
 //       CHECK-PROP:   }
 func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
@@ -1151,7 +1150,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 //       CHECK-PROP:   %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]]
 //       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index
 //       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
-//       CHECK-PROP:     %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index]
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [%[[INSERTING_POS]]]
 //       CHECK-PROP:     scf.yield %[[INSERT]]
 //       CHECK-PROP:   } else {
 //       CHECK-PROP:     scf.yield %[[W]]#0
@@ -1175,7 +1174,7 @@ func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
 //       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
-//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32>
+//       CHECK-PROP:   vector.insert %[[W]]#1, %[[W]]#0 [%[[POS]]] : f32 into vector<96xf32>
 func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
     %0 = "some_def"() : () -> (vector<96xf32>)
@@ -1193,7 +1192,7 @@ func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (ve
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
 //       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
-//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[] : vector<f32>
+//       CHECK-PROP:   vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
 func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)
@@ -1208,7 +1207,6 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
 
 // CHECK-PROP-LABEL: func @vector_insert_1d(
 //  CHECK-PROP-SAME:     %[[LANEID:.*]]: index
-//   CHECK-PROP-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-PROP-DAG:   %[[C26:.*]] = arith.constant 26 : index
 //       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
@@ -1216,7 +1214,7 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
 //       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
 //       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]]
 //       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
-//       CHECK-PROP:     %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index]
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [1]
 //       CHECK-PROP:     scf.yield %[[INSERT]]
 //       CHECK-PROP:   } else {
 //       CHECK-PROP:     scf.yield %[[W]]#0
@@ -1316,14 +1314,14 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 //       CHECK-PROP:     %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
 //       CHECK-PROP:     %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
 //       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
-//       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
+//       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex>
 //       CHECK-PROP:     vector.yield %[[EXTRACTELT]] : index
 //       CHECK-PROP:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
 //       CHECK-PROP:   %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]],
 //       CHECK-PROP:   return %[[TRANSFERREAD]]
 func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
   %0 = gpu.thread_id  x
-  %c0_i32 = arith.constant 0 : i32
+  %c0_i32 = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0> : vector<1x64xi32>
   %cst_0 = arith.constant dense<true> : vector<1x64xi1>
@@ -1336,7 +1334,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1
     %28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
     %29 = vector.extract %28[0] : vector<64xi32> from vector<1x64xi32>
     %30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
-    %36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex>
+    %36 = vector.extractelement %30[%c0_i32 : index] : vector<64xindex>
     %37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
     vector.yield %37 : vector<64xf32>
   }

>From 823909cab5d5732f9eda7c34981cd6a98df4f1fd Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 18 Nov 2024 10:19:39 +0000
Subject: [PATCH 2/2] Address comments

---
 .../Vector/Transforms/VectorDistribute.cpp    | 41 ++++++++-----------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bad26b0cc4b383..dc5eb2527f949a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1355,12 +1355,9 @@ struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // All lanes extract the scalar.
     if (is0dOrVec1Extract) {
       Value newExtract;
-      if (extractSrcType.getRank() == 1) {
-        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
-      } else {
-        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
-                                                        ArrayRef<int64_t>{});
-      }
+      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
+      newExtract =
+          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newExtract);
       return success();
@@ -1408,14 +1405,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (!operand)
       return failure();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    rewriter.setInsertionPoint(extractOp);
+    SmallVector<OpFoldResult> indices;
     if (auto pos = extractOp.getPosition()) {
-      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-          extractOp, extractOp.getVector(), pos);
-    } else {
-      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-          extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+      indices.push_back(pos);
     }
+    rewriter.setInsertionPoint(extractOp);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, extractOp.getVector(), indices);
     return success();
   }
 };
@@ -1472,13 +1468,12 @@ struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // This condition is always true for 0-d vectors.
     if (vecType == distrType) {
       Value newInsert;
+      SmallVector<OpFoldResult> indices;
       if (pos) {
-        newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
-                                                      distributedVec, pos);
-      } else {
-        newInsert = rewriter.create<vector::InsertOp>(
-            loc, newSource, distributedVec, ArrayRef<int64_t>{});
+        indices.push_back(pos);
       }
+      newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+                                                    distributedVec, indices);
       // Broadcast: Simply move the vector.insert op out.
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newInsert);
@@ -1640,15 +1635,13 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (!operand)
       return failure();
     auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
-    rewriter.setInsertionPoint(insertOp);
+    SmallVector<OpFoldResult> indices;
     if (auto pos = insertOp.getPosition()) {
-      rewriter.replaceOpWithNewOp<vector::InsertOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
-    } else {
-      rewriter.replaceOpWithNewOp<vector::InsertOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          ArrayRef<int64_t>{});
+      indices.push_back(pos);
     }
+    rewriter.setInsertionPoint(insertOp);
+    rewriter.replaceOpWithNewOp<vector::InsertOp>(
+        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
     return success();
   }
 };



More information about the Mlir-commits mailing list