[Mlir-commits] [mlir] 20df17f - [mlir][vector] Extend WarpExecutionOnLane0 pattern support to allow deduplicating identical yield values.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 9 06:53:45 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-09T06:53:36-07:00
New Revision: 20df17fd2db16b4df20fe275bf9f062eaecf9745

URL: https://github.com/llvm/llvm-project/commit/20df17fd2db16b4df20fe275bf9f062eaecf9745
DIFF: https://github.com/llvm/llvm-project/commit/20df17fd2db16b4df20fe275bf9f062eaecf9745.diff

LOG: [mlir][vector] Extend WarpExecutionOnLane0 pattern support to allow deduplicating identical yield values.

Differential Revision: https://reviews.llvm.org/D133573

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f356248f18311..2c757be09a6a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -753,27 +753,50 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<Type> resultTypes;
-    SmallVector<Value> yieldValues;
+    SmallVector<Type> newResultTypes;
+    newResultTypes.reserve(warpOp->getNumResults());
+    SmallVector<Value> newYieldValues;
+    newYieldValues.reserve(warpOp->getNumResults());
+    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
+    DenseMap<OpResult, int64_t> dedupResultPositionMap;
     auto yield = cast<vector::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
+    // Some values may be yielded multiple times and correspond to multiple
+    // results. Deduplicating occurs by taking each result with its matching
+    // yielded value, and:
+    //   1. recording the unique first position at which the value is yielded.
+    //   2. recording for the result, the first position at which the dedup'ed
+    //      value is yielded.
+    //   3. skipping from the new result types / new yielded values any result
+    //      that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
-      if (result.use_empty())
+      Value yieldOperand = yield.getOperand(result.getResultNumber());
+      auto it = dedupYieldOperandPositionMap.insert(
+          std::make_pair(yieldOperand, newResultTypes.size()));
+      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
+      if (result.use_empty() || !it.second)
         continue;
-      resultTypes.push_back(result.getType());
-      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
+      newResultTypes.push_back(result.getType());
+      newYieldValues.push_back(yieldOperand);
     }
-    if (yield.getNumOperands() == yieldValues.size())
+    // No modification, exit early.
+    if (yield.getNumOperands() == newYieldValues.size())
       return failure();
+    // Move the body of the old warpOp to a new warpOp.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, yieldValues, resultTypes);
-    unsigned resultIndex = 0;
+        rewriter, warpOp, newYieldValues, newResultTypes);
+    // Replace results of the old warpOp by the new, deduplicated results.
+    SmallVector<Value> newValues;
+    newValues.reserve(warpOp->getNumResults());
     for (OpResult result : warpOp.getResults()) {
       if (result.use_empty())
-        continue;
-      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
+        newValues.push_back(Value());
+      else
+        newValues.push_back(
+            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
     }
-    rewriter.eraseOp(warpOp);
+    rewriter.replaceOp(warpOp, newValues);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 09d41db642fb2..5a70ae8a4994a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -669,3 +669,25 @@ func.func @dont_duplicate_read(
   }
   return %r : vector<1xf32>
 }
+
+// -----
+
+// CHECK-PROP:   func @dedup
+func.func @dedup(%laneid: index, %v0: vector<4xf32>, %v1: vector<4xf32>) 
+    -> (vector<1xf32>, vector<1xf32>) {
+
+  // CHECK-PROP: %[[SINGLE_RES:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) {
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%v0, %v1 : vector<4xf32>, vector<4xf32>) -> (vector<1xf32>, vector<1xf32>) {
+    ^bb0(%arg0: vector<128xf32>, %arg1: vector<128xf32>):
+
+    // CHECK-PROP: %[[SINGLE_VAL:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>) -> vector<32xf32>
+    %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
+
+    // CHECK-PROP: vector.yield %[[SINGLE_VAL]] : vector<32xf32>
+    vector.yield %2, %2 : vector<32xf32>, vector<32xf32>
+  }
+
+  // CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}


        


More information about the Mlir-commits mailing list