[Mlir-commits] [mlir] 91f62f0 - [mlir][vector] Fix distribution of scf.for with value coming from above

Thomas Raoux llvmlistbot at llvm.org
Tue Nov 1 21:16:44 PDT 2022


Author: Thomas Raoux
Date: 2022-11-02T04:15:18Z
New Revision: 91f62f0e352a4f5c755f1cbec6f27e40a60ff109

URL: https://github.com/llvm/llvm-project/commit/91f62f0e352a4f5c755f1cbec6f27e40a60ff109
DIFF: https://github.com/llvm/llvm-project/commit/91f62f0e352a4f5c755f1cbec6f27e40a60ff109.diff

LOG: [mlir][vector] Fix distribution of scf.for with value coming from above

When a value used in the forOp is defined outside its region but within
the parent warpOp, we need to return the value from the warpOp and
distribute it so it can be passed to the new operations created within
the loop.
Also simplify the lambda interface.
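
As a sketch of the simplified lambda interface (mirroring the callback in
TestVectorTransforms.cpp below; the map distributes the innermost vector
dimension):

    // The callback now receives a Value rather than a vector::TransferWriteOp,
    // so the same function can describe values escaping into an scf.for.
    auto distributionFn = [](Value val) {
      VectorType vecType = val.getType().dyn_cast<VectorType>();
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
    };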

Differential Revision: https://reviews.llvm.org/D137146

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 204b322e2deae..49e34274f9891 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,7 +40,7 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
     const WarpExecuteOnLane0LoweringOptions &options,
     PatternBenefit benefit = 1);
 
-using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
@@ -67,9 +67,12 @@ void populateDistributeTransferWriteOpPatterns(
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
-/// Collect patterns to propagate warp distribution.
+/// Collect patterns to propagate warp distribution. `distributionMapFn` is used
+/// to decide how a value should be distributed when this cannot be inferred
+/// from its uses.
 void populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &pattern, PatternBenefit benefit = 1);
+    RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.

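On the client side, the updated entry point threads the map function through
to the scf.for pattern; a minimal usage sketch (mirroring the test pass in
TestVectorTransforms.cpp, inside a pass's runOnOperation):

    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // distributionFn decides how a value is distributed when this cannot be
    // inferred from its uses, e.g. a value flowing into an scf.for region.
    vector::populatePropagateWarpVectorDistributionPatterns(patterns,
                                                            distributionFn);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
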
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f730044abcf85..6dfdf766a2f62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include <utility>
@@ -421,6 +422,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
   return newWriteOp;
 }
 
+/// Return the distributed vector type based on the original type and the
+/// distribution map. The map is expected to have as many dimensions as the
+/// rank of the original type and should be a projection whose results are the
+/// distributed dimensions. The number of results should equal the number of
+/// warp sizes, which is currently limited to 1.
+/// Example: distributing vector<16x32x64> with the map (d0, d1, d2) -> (d1)
+/// and a warp size of 16 splits the second dimension (associated with d1) by
+/// 16 and returns vector<16x2x64>.
+static VectorType getDistributedType(VectorType originalType, AffineMap map,
+                                     int64_t warpSize) {
+  if (map.getNumResults() != 1)
+    return VectorType();
+  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
+                                   originalType.getShape().end());
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    unsigned position = map.getDimPosition(i);
+    if (targetShape[position] % warpSize != 0)
+      return VectorType();
+    targetShape[position] = targetShape[position] / warpSize;
+  }
+  VectorType targetType =
+      VectorType::get(targetShape, originalType.getElementType());
+  return targetType;
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
 /// Example:
@@ -456,29 +482,19 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (writtenVectorType.getRank() == 0)
       return failure();
 
-    // 2. Compute the distribution map.
-    AffineMap map = distributionMapFn(writeOp);
-    if (map.getNumResults() != 1)
-      return writeOp->emitError("multi-dim distribution not implemented yet");
-
-    // 3. Compute the targetType using the distribution map.
-    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
-                                     writtenVectorType.getShape().end());
-    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-      unsigned position = map.getDimPosition(i);
-      if (targetShape[position] % warpOp.getWarpSize() != 0)
-        return failure();
-      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
-    }
+    // 2. Compute the distributed type.
+    AffineMap map = distributionMapFn(writeOp.getVector());
     VectorType targetType =
-        VectorType::get(targetShape, writtenVectorType.getElementType());
+        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
+    if (!targetType)
+      return failure();
 
-    // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
+    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
         cloneWriteOp(rewriter, warpOp, writeOp, targetType);
 
-    // 5. Reindex the write using the distribution map.
+    // 4. Reindex the write using the distribution map.
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
     rewriter.setInsertionPoint(newWriteOp);
@@ -494,7 +510,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
+      auto scale =
+          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
       indices[indexPos] =
           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                   {indices[indexPos], newWarpOp.getLaneid()});
@@ -956,6 +973,10 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///  }
 /// ```
 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+
+  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+        distributionMapFn(std::move(fn)) {}
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -966,6 +987,35 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
+    // Collect values that come from the warp op but are defined outside the
+    // forOp. Those values need to be returned by the original warpOp and
+    // passed to the new op.
+    llvm::SmallSetVector<Value, 32> escapingValues;
+    SmallVector<Type> inputTypes;
+    SmallVector<Type> distTypes;
+    mlir::visitUsedValuesDefinedAbove(
+        forOp.getBodyRegion(), [&](OpOperand *operand) {
+          Operation *parent = operand->get().getParentRegion()->getParentOp();
+          if (warpOp->isAncestor(parent)) {
+            if (!escapingValues.insert(operand->get()))
+              return;
+            Type distType = operand->get().getType();
+            if (auto vecType = distType.dyn_cast<VectorType>()) {
+              AffineMap map = distributionMapFn(operand->get());
+              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+            }
+            inputTypes.push_back(operand->get().getType());
+            distTypes.push_back(distType);
+          }
+        });
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
+        newRetIndices);
+    yield = cast<vector::YieldOp>(
+        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
     SmallVector<Value> newOperands;
     SmallVector<unsigned> resultIdx;
     // Collect all the outputs coming from the forOp.
@@ -973,28 +1023,42 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
         continue;
       auto forResult = yieldOperand.get().cast<OpResult>();
-      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
+      newOperands.push_back(
+          newWarpOp.getResult(yieldOperand.getOperandNumber()));
       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
       resultIdx.push_back(yieldOperand.getOperandNumber());
     }
+
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
     // inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newOperands);
     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+
+    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
+                                 newForOp.getRegionIterArgs().end());
+    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
+                                    forOp.getResultTypes().end());
+    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
+    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
+      warpInput.push_back(newWarpOp.getResult(retIdx));
+      argIndexMapping[escapingValues[i]] = warpInputType.size();
+      warpInputType.push_back(inputTypes[i]);
+    }
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
-        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
-        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
-        forOp.getResultTypes());
+        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
+        newWarpOp.getWarpSize(), warpInput, warpInputType);
 
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
     for (Value args : innerWarp.getBody()->getArguments()) {
       argMapping.push_back(args);
     }
+    argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
@@ -1008,12 +1072,23 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.eraseOp(forOp);
     // Replace the warpOp result coming from the original ForOp.
     for (const auto &res : llvm::enumerate(resultIdx)) {
-      warpOp.getResult(res.value())
+      newWarpOp.getResult(res.value())
           .replaceAllUsesWith(newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
+      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
     }
+    newForOp.walk([&](Operation *op) {
+      for (OpOperand &operand : op->getOpOperands()) {
+        auto it = argIndexMapping.find(operand.get());
+        if (it == argIndexMapping.end())
+          continue;
+        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+      }
+    });
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1119,11 +1194,14 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
-               WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
-      patterns.getContext(), benefit);
+               WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
+                                                     benefit);
+  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
+                               benefit);
 }
 
 void mlir::vector::populateDistributeReduction(

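To make the shape arithmetic of getDistributedType concrete, a minimal sketch
(assuming a builder in scope; the helper is file-local to
VectorDistribute.cpp, so this is illustrative only):

    // Distribute vector<16x32x64xf32> along d1 across a warp of 16 lanes.
    VectorType orig = VectorType::get({16, 32, 64}, builder.getF32Type());
    // Projection map (d0, d1, d2) -> (d1) selects the distributed dimension.
    AffineMap map = AffineMap::get(3, 0, builder.getAffineDimExpr(1));
    // 32 / 16 == 2, so the result is vector<16x2x64xf32>.
    VectorType dist = getDistributedType(orig, map, /*warpSize=*/16);
    // A null VectorType signals that the distributed dimension is not a
    // multiple of the warp size; the calling pattern then fails to apply.
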
diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 49c36fe18c90d..daebccd92008d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -349,6 +349,40 @@ func.func @warp_scf_for(%arg0: index) {
 
 // -----
 
+// CHECK-PROP-LABEL:   func @warp_scf_for_use_from_above(
+// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   %[[USE:.*]] = "some_def_above"() : () -> vector<128xf32>
+// CHECK-PROP:   vector.yield %[[INI1]], %[[USE]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]#0) -> (vector<4xf32>) {
+// CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]], %[[INI]]#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:    ^bb0(%[[ARG0:.*]]: vector<128xf32>, %[[ARG1:.*]]: vector<128xf32>):
+// CHECK-PROP:      %[[ACC:.*]] = "some_def"(%[[ARG0]], %[[ARG1]]) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:      vector.yield %[[ACC]] : vector<128xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[W]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_use_from_above(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %use_from_above = "some_def_above"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4, %use_from_above) : (vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    vector.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL:   func @warp_scf_for_swap(
 // CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
 // CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>

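The @warp_scf_for_use_from_above test above exercises the new path:
%use_from_above is defined in the original warp op but only used inside the
scf.for, so it is returned from the warp op as a distributed vector<4xf32>,
threaded into the loop through the args(...) clause, and rematerialized as
the vector<128xf32> block argument %[[ARG1]] of the inner warp op.
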
diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 5547a964cf2cb..b66b2fe9ef7f8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -746,24 +746,26 @@ struct TestVectorDistribution
       }
     });
     MLIRContext *ctx = &getContext();
+    auto distributionFn = [](Value val) {
+      // Create a map (d0, d1) -> (d1) to distribute along the inner
+      // dimension. Once we support n-d distribution we can add more
+      // complex cases.
+      VectorType vecType = val.getType().dyn_cast<VectorType>();
+      int64_t vecRank = vecType ? vecType.getRank() : 0;
+      OpBuilder builder(val.getContext());
+      if (vecRank == 0)
+        return AffineMap::get(val.getContext());
+      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+    };
     if (distributeTransferWriteOps) {
-      auto distributionFn = [](vector::TransferWriteOp writeOp) {
-        // Create a map (d0, d1) -> (d1) to distribute along the inner
-        // dimension. Once we support n-d distribution we can add more
-        // complex cases.
-        int64_t vecRank = writeOp.getVectorType().getRank();
-        OpBuilder builder(writeOp.getContext());
-        auto map =
-            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
-        return map;
-      };
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }
     if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
-      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
+      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
+                                                              distributionFn);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }

