[Mlir-commits] [mlir] [mlir][Vector] Move insert/extractelement distribution patterns to insert/extract (PR #116425)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 15 22:06:14 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This is an NFC-ish change that moves the vector.extractelement/vector.insertelement vector distribution patterns over to vector.insert/vector.extract.

Before:

- 0-d/1-d `vector.extract` -> `vector.extractelement` -> distributed `vector.extractelement`
- 2-d+ `vector.extract` -> distributed `vector.extract`

After (see the sketch below):

- scalar-result `vector.extract` (0-d/1-d source) -> distributed `vector.extract`
- `vector.extractelement` -> `vector.extract` -> distributed `vector.extract`
- 2-d+ `vector.extract` -> distributed `vector.extract`
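
To make the new target forms concrete, here is a minimal, purely illustrative MLIR sketch (all value names are hypothetical and not taken from the patch) of the scalar-result extracts and scalar-source inserts that the updated patterns now distribute directly:

```mlir
// Illustrative only: scalar-result extracts / scalar-source inserts on
// 0-d and 1-d vectors, which now lower straight to a distributed
// vector.extract / vector.insert.
%e0 = vector.extract %v0[] : f32 from vector<f32>         // 0-d source
%e1 = vector.extract %v1[%idx] : f32 from vector<64xf32>  // 1-d source, dynamic position
%i0 = vector.insert %s, %v1[3] : f32 into vector<64xf32>  // 1-d destination, static position
```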

The same change is made for insertelement/insert. This removes the reliance on vector.extractelement/vector.insertelement, which are soon to be deprecated: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116/8
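
For reference, the migration the RFC asks for is a one-to-one replacement at the op level; a hedged sketch (hypothetical values, not from the patch) of the deprecated element ops next to their vector.extract/vector.insert equivalents:

```mlir
// Deprecated element ops and their extract/insert equivalents
// (illustrative values only).
%a = vector.extractelement %v[%c1 : index] : vector<4xf32>
%b = vector.extract %v[%c1] : f32 from vector<4xf32>       // replacement

%c = vector.insertelement %s, %v[%c1 : index] : vector<4xf32>
%d = vector.insert %s, %v[%c1] : f32 into vector<4xf32>    // replacement
```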

No extra tests are included because this patch doesn't introduce or remove any functionality; it only changes the chain of lowerings. The change could be made completely NFC by keeping vector.extractelement/vector.insertelement as the distributed operations, but that would be slightly odd, because the lowering would then go extractelement -> extract -> extractelement.

---

Patch is 27.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116425.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+137-104) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+12-14) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 682eb82ac58408..bad26b0cc4b383 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     VectorType extractSrcType = extractOp.getSourceVectorType();
     Location loc = extractOp.getLoc();
 
-    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
-    assert(extractSrcType.getRank() > 0 &&
-           "vector.extract does not support rank 0 sources");
-
-    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
-    // canonicalized to %v.
-    if (extractOp.getNumIndices() == 0)
+    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
+    if (extractSrcType.getRank() <= 1) {
       return failure();
-
-    // Rewrite vector.extract with 1d source to vector.extractelement.
-    if (extractSrcType.getRank() == 1) {
-      if (extractOp.hasDynamicPosition())
-        // TODO: Dinamic position not supported yet.
-        return failure();
-
-      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
-      int64_t pos = extractOp.getStaticPosition()[0];
-      rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
-      return success();
     }
 
     // All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
-/// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
-                       PatternBenefit b = 1)
+/// Pattern to move out vector.extract with a scalar result.
+/// Only supports 1-D and 0-D sources for now.
+struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+                      PatternBenefit b = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
         warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
     VectorType extractSrcType = extractOp.getSourceVectorType();
+    // Only supports 1-D or 0-D sources for now.
+    if (extractSrcType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "only 0-D or 1-D source supported for now");
+    }
     // TODO: Supported shuffle types should be parameterizable, similar to
     // `WarpShuffleFromIdxFn`.
     if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     VectorType distributedVecType;
     if (!is0dOrVec1Extract) {
       assert(extractSrcType.getRank() == 1 &&
-             "expected that extractelement src rank is 0 or 1");
+             "expected that extract src rank is 0 or 1");
       if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
         return failure();
       int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Yield source vector and position (if present) from warp op.
     SmallVector<Value> additionalResults{extractOp.getVector()};
     SmallVector<Type> additionalResultTypes{distributedVecType};
-    if (static_cast<bool>(extractOp.getPosition())) {
-      additionalResults.push_back(extractOp.getPosition());
-      additionalResultTypes.push_back(extractOp.getPosition().getType());
-    }
+    additionalResults.append(
+        SmallVector<Value>(extractOp.getDynamicPosition()));
+    additionalResultTypes.append(
+        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
+
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (is0dOrVec1Extract) {
       Value newExtract;
       if (extractSrcType.getRank() == 1) {
-        newExtract = rewriter.create<vector::ExtractElementOp>(
-            loc, distributedVec,
-            rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
       } else {
-        newExtract =
-            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+        newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
+                                                        ArrayRef<int64_t>{});
       }
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newExtract);
       return success();
     }
 
+    int64_t staticPos = extractOp.getStaticPosition()[0];
+    OpFoldResult pos = ShapedType::isDynamic(staticPos)
+                           ? (newWarpOp->getResult(newRetIndices[1]))
+                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
     // 1d extract: Distribute the source vector. One lane extracts and shuffles
     // the value to all other lanes.
     int64_t elementsPerLane = distributedVecType.getShape()[0];
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
-    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane),
-        newWarpOp->getResult(newRetIndices[1]));
+    Value broadcastFromTid = affine::makeComposedAffineApply(
+        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
     // Extract at position: pos % elementsPerLane
-    Value pos =
+    Value newPos =
         elementsPerLane == 1
             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
-            : rewriter
-                  .create<affine::AffineApplyOp>(
-                      loc, sym0 % elementsPerLane,
-                      newWarpOp->getResult(newRetIndices[1]))
-                  .getResult();
+            : affine::makeComposedAffineApply(rewriter, loc,
+                                              sym0 % elementsPerLane, pos);
     Value extracted =
-        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
 
     // Shuffle the extracted value to all lanes.
     Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
-struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+/// Pattern to convert vector.extractelement to vector.extract.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+    if (!operand)
+      return failure();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    rewriter.setInsertionPoint(extractOp);
+    if (auto pos = extractOp.getPosition()) {
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
+    } else {
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+    }
+    return success();
+  }
+};
+
+/// Pattern to move out vector.insert with a scalar input.
+/// Only supports 1-D and 0-D destinations for now.
+struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
     VectorType vecType = insertOp.getDestVectorType();
     VectorType distrType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
-    bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+    // Only supports 1-D or 0-D destinations for now.
+    if (vecType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          insertOp, "only 0-D or 1-D source supported for now");
+    }
 
     // Yield destination vector, source scalar and position from warp op.
     SmallVector<Value> additionalResults{insertOp.getDest(),
                                          insertOp.getSource()};
     SmallVector<Type> additionalResultTypes{distrType,
                                             insertOp.getSource().getType()};
-    if (hasPos) {
-      additionalResults.push_back(insertOp.getPosition());
-      additionalResultTypes.push_back(insertOp.getPosition().getType());
-    }
+    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
+    additionalResultTypes.append(
+        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
+
     Location loc = insertOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
     Value newSource = newWarpOp->getResult(newRetIndices[1]);
-    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
     rewriter.setInsertionPointAfter(newWarpOp);
 
+    OpFoldResult pos;
+    if (vecType.getRank() != 0) {
+      int64_t staticPos = insertOp.getStaticPosition()[0];
+      pos = ShapedType::isDynamic(staticPos)
+                ? (newWarpOp->getResult(newRetIndices[2]))
+                : OpFoldResult(rewriter.getIndexAttr(staticPos));
+    }
+
+    // This condition is always true for 0-d vectors.
     if (vecType == distrType) {
-      // Broadcast: Simply move the vector.inserelement op out.
-      Value newInsert = rewriter.create<vector::InsertElementOp>(
-          loc, newSource, distributedVec, newPos);
+      Value newInsert;
+      if (pos) {
+        newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+                                                      distributedVec, pos);
+      } else {
+        newInsert = rewriter.create<vector::InsertOp>(
+            loc, newSource, distributedVec, ArrayRef<int64_t>{});
+      }
+      // Broadcast: Simply move the vector.insert op out.
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newInsert);
       return success();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     int64_t elementsPerLane = distrType.getShape()[0];
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
-    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane), newPos);
+    Value insertingLane = affine::makeComposedAffineApply(
+        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
     // Insert position: pos % elementsPerLane
-    Value pos =
-        elementsPerLane == 1
-            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
-            : rewriter
-                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                                 newPos)
-                  .getResult();
+    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, sym0 % elementsPerLane, pos);
     Value isInsertingLane = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
     Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
                 loc, isInsertingLane,
                 /*thenBuilder=*/
                 [&](OpBuilder &builder, Location loc) {
-                  Value newInsert = builder.create<vector::InsertElementOp>(
-                      loc, newSource, distributedVec, pos);
+                  Value newInsert = builder.create<vector::InsertOp>(
+                      loc, newSource, distributedVec, newPos);
                   builder.create<scf::YieldOp>(loc, newInsert);
                 },
                 /*elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
     Location loc = insertOp.getLoc();
 
-    // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
-    if (insertOp.getNumIndices() == 0)
+    // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
+    if (insertOp.getDestVectorType().getRank() <= 1) {
       return failure();
-
-    // Rewrite vector.insert with 1d dest to vector.insertelement.
-    if (insertOp.getDestVectorType().getRank() == 1) {
-      if (insertOp.hasDynamicPosition())
-        // TODO: Dinamic position not supported yet.
-        return failure();
-
-      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
-      int64_t pos = insertOp.getStaticPosition()[0];
-      rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
-      return success();
     }
 
+    // All following cases are 2d or higher dimensional source vectors.
+
     if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
       // There is no distribution, this is a broadcast. Simply move the insert
       // out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+    if (!operand)
+      return failure();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    rewriter.setInsertionPoint(insertOp);
+    if (auto pos = insertOp.getPosition()) {
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
+    } else {
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(),
+          ArrayRef<int64_t>{});
+    }
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
-/// the scf.ForOp is the last operation in the region so that it doesn't change
-/// the order of execution. This creates a new scf.for region after the
+/// the scf.ForOp is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.for region after the
 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
 /// WarpExecuteOnLane0Op region. Example:
 /// ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to the
-    // new op.
+    // Those Value needs to be returned by the original warpOp and passed to
+    // the new op.
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
 
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
-    // inside.
+    // Create a new for op outside the region with a WarpExecuteOnLane0Op
+    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
 };
 
 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
-/// The vector is reduced in parallel. Currently limited to vector size matching
-/// the warpOp size. E.g.:
+/// The vector is reduced in parallel. Currently limited to vector size
+/// matching the warpOp size. E.g.:
 /// ```
 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
 ///   %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
-          patterns.getContext(), benefit);
-  patterns.add<WarpOpExtractElement>(patterns.getContext(),
-                                     warpShuffleFromIdxFn, benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+      patterns.getContext(), benefit);
+  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
+                                    benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3acddd6e54639e..b4491812dc26cb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -783,7 +783,7 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<64xf32>
 //       CHECK-PROP:     vector.yield %[[V]] : vector<64xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+//       CHE...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116425


More information about the Mlir-commits mailing list