[Mlir-commits] [mlir] d7d6443 - [mlir][vector] Avoid creating duplicate output in warpOp

Thomas Raoux llvmlistbot at llvm.org
Mon Jul 11 08:42:08 PDT 2022


Author: Thomas Raoux
Date: 2022-07-11T15:37:50Z
New Revision: d7d6443d501839ef806f9dc872900451d7b41927

URL: https://github.com/llvm/llvm-project/commit/d7d6443d501839ef806f9dc872900451d7b41927
DIFF: https://github.com/llvm/llvm-project/commit/d7d6443d501839ef806f9dc872900451d7b41927.diff

LOG: [mlir][vector] Avoid creating duplicate output in warpOp

Prevent creating multiple outputs for the same Value when distributing
operations out of WarpExecuteOnLane0Op. This avoids a combinatorial
explosion of outputs.

Differential Revision: https://reviews.llvm.org/D129465
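
The core of the change: instead of blindly appending every requested value to the
warp op's yield, moveRegionToNewWarpOpAndAppendReturns now deduplicates the yielded
values and hands each caller the result index its value ended up at. Below is a
minimal, standalone sketch of that deduplication idea, not the actual MLIR helper:
it uses plain C++ (std::vector standing in for llvm::SmallSetVector, strings
standing in for mlir::Value) and illustrative names.

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Append `newYieldedValues` to the existing yield list, reusing a slot when
    // the value is already yielded. Returns one index per requested value.
    std::vector<size_t>
    appendDeduplicated(std::vector<std::string> &yieldValues,
                       const std::vector<std::string> &newYieldedValues) {
      std::vector<size_t> indices;
      for (const std::string &v : newYieldedValues) {
        size_t slot = 0;
        for (; slot < yieldValues.size(); ++slot)
          if (yieldValues[slot] == v)
            break;                  // Value already yielded: reuse its slot.
        if (slot == yieldValues.size())
          yieldValues.push_back(v); // New value: create a new output slot.
        indices.push_back(slot);
      }
      return indices;
    }

    int main() {
      std::vector<std::string> yields = {"%0"}; // Existing warp op yields.
      auto idx = appendDeduplicated(yields, {"%2", "%2", "%3"});
      // "%2" is requested twice but gets a single slot; idx is {1, 1, 2}.
      for (size_t i : idx)
        std::cout << i << " ";
      std::cout << "\n"; // Prints: 1 1 2
    }

In the actual patch, llvm::SmallSetVector provides the same reuse-or-append
behavior, and each rewrite pattern reads its distributed value through the
returned newRetIndices instead of assuming the new output is the last result.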

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2b9635835d7b1..57fa863320906 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -14,7 +14,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/SideEffectUtils.h"
-
+#include "llvm/ADT/SetVector.h"
 #include <utility>
 
 using namespace mlir;
@@ -165,19 +165,34 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
 }
 
 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
   SmallVector<Type> types(warpOp.getResultTypes().begin(),
                           warpOp.getResultTypes().end());
-  types.append(newReturnTypes.begin(), newReturnTypes.end());
   auto yield = cast<vector::YieldOp>(
       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  SmallVector<Value> yieldValues(yield.getOperands().begin(),
-                                 yield.getOperands().end());
-  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand.value() == std::get<0>(newRet)) {
+          indices.push_back(yieldOperand.index());
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues, types);
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
   rewriter.replaceOp(warpOp,
                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
   return newWarpOp;
@@ -273,14 +288,15 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
   assert(writeOp->getParentOp() == warpOp &&
          "write must be nested immediately under warp");
   OpBuilder::InsertionGuard g(rewriter);
+  SmallVector<size_t> newRetIndices;
   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
       rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-      TypeRange{targetType});
+      TypeRange{targetType}, newRetIndices);
   rewriter.setInsertionPointAfter(newWarpOp);
   auto newWriteOp =
       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
   rewriter.eraseOp(writeOp);
-  newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
   return newWriteOp;
 }
 
@@ -387,8 +403,9 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
 
     SmallVector<Value> yieldValues = {writeOp.getVector()};
     SmallVector<Type> retTypes = {vecType};
+    SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, yieldValues, retTypes);
+        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
 
     // Create a second warp op that contains only writeOp.
@@ -398,8 +415,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     rewriter.setInsertionPointToStart(&body);
     auto newWriteOp =
         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
-    newWriteOp.getVectorMutable().assign(
-        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
+    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
     rewriter.eraseOp(writeOp);
     rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
     return success();
@@ -489,14 +505,14 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
       retTypes.push_back(targetType);
       yieldValues.push_back(operand.get());
     }
-    unsigned numResults = warpOp.getNumResults();
+    SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, yieldValues, retTypes);
+        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                    elementWise->getOperands().end());
     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
-      newOperands[i] = newWarpOp.getResult(i + numResults);
+      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
     }
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
@@ -653,12 +669,13 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = broadcastOp.getLoc();
     auto destVecType =
         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+    SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, {broadcastOp.getSource()},
-        {broadcastOp.getSource().getType()});
+        {broadcastOp.getSource().getType()}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
-        loc, destVecType, newWarpOp->getResults().back());
+        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
     return success();
   }
@@ -814,12 +831,12 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
     SmallVector<Value> yieldValues = {reductionOp.getVector()};
     SmallVector<Type> retTypes = {
         VectorType::get({numElements}, reductionOp.getType())};
-    unsigned numResults = warpOp.getNumResults();
+    SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, yieldValues, retTypes);
+        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
 
-    Value laneValVec = newWarpOp.getResult(numResults);
+    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
     // First reduce on a single thread.
     Value perLaneReduction = rewriter.create<vector::ReductionOp>(
         reductionOp.getLoc(), reductionOp.getKind(), laneValVec);

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 82f6299634578..4a04f988be979 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -545,3 +545,20 @@ func.func @vector_reduction_large(%laneid: index) -> (f32) {
   }
   return %r : f32
 }
+
+// -----
+
+// CHECK-PROP-LABEL:   func @warp_duplicate_yield(
+func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>) {
+  //   CHECK-PROP: %{{.*}}:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>)
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, vector<1xf32>) {
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<32xf32>)
+    %4 = arith.addf %2, %3 : vector<32xf32>
+    %5 = arith.addf %2, %2 : vector<32xf32>
+// CHECK-PROP-NOT:   arith.addf
+//     CHECK-PROP:   vector.yield %{{.*}}, %{{.*}} : vector<32xf32>, vector<32xf32>
+    vector.yield %4, %5 : vector<32xf32>, vector<32xf32>
+  }
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
