[Mlir-commits] [mlir] 76cf33d - [mlir][vector] Add patterns to propagate vector distribution

Thomas Raoux llvmlistbot at llvm.org
Mon Jun 13 22:28:19 PDT 2022


Author: Thomas Raoux
Date: 2022-06-14T05:26:10Z
New Revision: 76cf33dab2d8846322f35d4065eec1562b563f45

URL: https://github.com/llvm/llvm-project/commit/76cf33dab2d8846322f35d4065eec1562b563f45
DIFF: https://github.com/llvm/llvm-project/commit/76cf33dab2d8846322f35d4065eec1562b563f45.diff

LOG: [mlir][vector] Add patterns to propagate vector distribution

Add patterns to propagate vector distribution and remove dead
results. This handles propagation for several vector operations.

Recommit after a minor bug fix.

Differential Revision: https://reviews.llvm.org/D127167

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir
    mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index b95b527d0639c..5af0da2f8528c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -65,6 +65,10 @@ void populateDistributeTransferWriteOpPatterns(
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
+/// Collect patterns that propagate warp distribution: they sink distributable
+/// operations out of WarpExecuteOnLane0Op regions and clean up the results
+/// left unused by doing so.
+void populatePropagateWarpVectorDistributionPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 586604f6fd6c3..662f62f50376e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 
 using namespace mlir;
@@ -181,6 +182,60 @@ static bool canBeHoisted(Operation *op,
          isSideEffectFree(op) && op->getNumRegions() == 0;
 }
 
+/// Find an operand of `warpOp`'s vector.yield whose value is produced by an
+/// operation accepted by `fn`, and whose matching warp op result still has
+/// uses. Returns the first such yield operand, or nullptr if there is none.
+static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                                std::function<bool(Operation *)> fn) {
+  auto terminator = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (OpOperand &candidate : terminator->getOpOperands()) {
+    Operation *producer = candidate.get().getDefiningOp();
+    // Block arguments and values rejected by the filter are never selected.
+    if (!producer || !fn(producer))
+      continue;
+    // Skip values whose warp result is dead; sinking them would be useless.
+    if (!warpOp.getResult(candidate.getOperandNumber()).use_empty())
+      return &candidate;
+  }
+  return nullptr;
+}
+
+// Creates, at `loc`, an operation with the same name and attributes as `op`
+// that takes `operands` and produces `resultTypes`. No regions or successors
+// are added to the new operation.
+static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
+                                              Location loc, Operation *op,
+                                              ArrayRef<Value> operands,
+                                              ArrayRef<Type> resultTypes) {
+  OperationState state(loc, op->getName().getStringRef(), operands,
+                       resultTypes, op->getAttrs());
+  Operation *clone = rewriter.create(state);
+  return clone;
+}
+
+/// Infer which dimensions of a yielded vector are distributed across lanes.
+/// The distribution map is not stored on the op yet (it will be in the
+/// future), so a dimension is treated as distributed whenever its size differs
+/// between the value yielded inside the region (`yield`) and the matching warp
+/// op result (`ret`). Distributed dimensions are associated with IDs in
+/// increasing dimension order.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
+///   ...
+///   vector.yield %3 : vector<32x16x64xf32>
+/// }
+/// ```
+/// has the implicit map `(d0, d1, d2) -> (d0, d2)`.
+static AffineMap calculateImplicitMap(Value yield, Value ret) {
+  auto yieldedType = yield.getType().cast<VectorType>();
+  auto distributedType = ret.getType().cast<VectorType>();
+  MLIRContext *ctx = yield.getContext();
+  unsigned rank = yieldedType.getRank();
+  SmallVector<AffineExpr> distributedDims;
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    // Dimensions of identical size are not distributed.
+    if (yieldedType.getDimSize(dim) == distributedType.getDimSize(dim))
+      continue;
+    distributedDims.push_back(getAffineDimExpr(dim, ctx));
+  }
+  return AffineMap::get(rank, /*symbolCount=*/0, distributedDims, ctx);
+}
+
 namespace {
 
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -350,6 +405,323 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
   DistributionMapFn distributionMapFn;
 };
 
+/// Sink out elementwise op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %3 = arith.addf %1, %2 : vector<32xf32>
+///   vector.yield %3 : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
+/// vector<1xf32>, vector<1xf32>) {
+///   ...
+///   %3 = arith.addf %1, %2 : vector<32xf32>
+///   vector.yield %3, %1, %2 : vector<32xf32>, vector<32xf32>,
+///   vector<32xf32>
+/// }
+/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
+/// ```
+/// Vector operands are yielded with the distributed shape of the result (and
+/// their own element type); scalar operands are not distributed and are
+/// yielded unchanged.
+struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
+      return OpTrait::hasElementwiseMappableTraits(op);
+    });
+    if (!yieldOperand)
+      return failure();
+    Operation *elementWise = yieldOperand->get().getDefiningOp();
+    // The op is re-created below with exactly one (distributed) result type.
+    if (elementWise->getNumResults() != 1)
+      return failure();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+    Value distributedVal = warpOp.getResult(operandIndex);
+    auto distributedVecType = distributedVal.getType().dyn_cast<VectorType>();
+    SmallVector<Value> yieldValues;
+    SmallVector<Type> retTypes;
+    Location loc = warpOp.getLoc();
+    for (OpOperand &operand : elementWise->getOpOperands()) {
+      Type targetType = operand.get().getType();
+      if (auto operandVecType = targetType.dyn_cast<VectorType>()) {
+        // A vector operand of an op with a scalar result is not supported.
+        if (!distributedVecType)
+          return failure();
+        targetType = VectorType::get(distributedVecType.getShape(),
+                                     operandVecType.getElementType());
+      }
+      // Scalar operands (e.g. the i1 condition of a select on vectors) are
+      // not distributed and keep their type.
+      retTypes.push_back(targetType);
+      yieldValues.push_back(operand.get());
+    }
+    unsigned numResults = warpOp.getNumResults();
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, yieldValues, retTypes);
+    // The operands of the sunk op are the results appended after the original
+    // ones, in operand order.
+    SmallVector<Value> newOperands;
+    newOperands.reserve(elementWise->getNumOperands());
+    for (unsigned i = 0, e = elementWise->getNumOperands(); i < e; ++i)
+      newOperands.push_back(newWarpOp.getResult(numResults + i));
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Operation *newOp = cloneOpWithOperandsAndTypes(
+        rewriter, loc, elementWise, newOperands,
+        {newWarpOp.getResult(operandIndex).getType()});
+    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Sink out transfer_read op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
+///   vector<32xf32>
+///   vector.yield %2 : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
+///   vector<32xf32>
+///   vector.yield %2 : vector<32xf32>
+/// }
+/// %0 = vector.transfer_read %src[%id], %cst : memref<1024xf32>, vector<1xf32>
+/// ```
+/// where `%id` is `%c0 + %arg0`: each index of a distributed dimension is
+/// offset by the lane id times the size of that dimension in the distributed
+/// vector. The now-unused warp result is left for other patterns to remove.
+struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
+    if (!operand)
+      return failure();
+    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
+    // The read is re-created after `warpOp`, so its source and padding must
+    // be defined outside of the warp op region.
+    auto isDefinedInWarpRegion = [&](Value value) {
+      return warpOp.getBodyRegion().isAncestor(value.getParentRegion());
+    };
+    if (isDefinedInWarpRegion(read.getSource()) ||
+        isDefinedInWarpRegion(read.getPadding()))
+      return failure();
+    // NOTE(review): indices are not checked. Distributed indices are rebuilt
+    // with makeComposedAffineApply, which presumably folds affine.apply
+    // producers defined in the region (the scf.for tests rely on this).
+    // A mask has the shape of the undistributed vector and would not match
+    // the distributed read.
+    if (read.getMask())
+      return failure();
+    unsigned operandIndex = operand->getOperandNumber();
+    Value distributedVal = warpOp.getResult(operandIndex);
+
+    SmallVector<Value, 4> indices(read.getIndices().begin(),
+                                  read.getIndices().end());
+    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
+    AffineMap indexMap = map.compose(read.getPermutationMap());
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(warpOp);
+    // New index = old index + laneid * (distributed size of that dimension).
+    AffineExpr d0, d1;
+    bindDims(read.getContext(), d0, d1);
+    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+      int64_t scale =
+          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
+      indices[indexPos] =
+          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
+                                  {indices[indexPos], warpOp.getLaneid()});
+    }
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
+        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+        read.getInBoundsAttr());
+    distributedVal.replaceAllUsesWith(newRead);
+    return success();
+  }
+};
+
+/// Drop every result of a WarpExecuteOnLane0Op that has no uses, together with
+/// the vector.yield operand feeding it, by rebuilding the warp op with only the
+/// live results.
+// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
+struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto terminator = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    SmallVector<OpResult> liveResults;
+    for (OpResult result : warpOp.getResults())
+      if (!result.use_empty())
+        liveResults.push_back(result);
+    // Nothing to do when every yielded value is still used.
+    if (liveResults.size() == terminator.getNumOperands())
+      return failure();
+    SmallVector<Type> keptTypes;
+    SmallVector<Value> keptYields;
+    for (OpResult result : liveResults) {
+      keptTypes.push_back(result.getType());
+      keptYields.push_back(terminator.getOperand(result.getResultNumber()));
+    }
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, keptYields, keptTypes);
+    // Live results keep their relative order in the new op.
+    for (const auto &it : llvm::enumerate(liveResults))
+      it.value().replaceAllUsesWith(newWarpOp.getResult(it.index()));
+    rewriter.eraseOp(warpOp);
+    return success();
+  }
+};
+
+/// Forward a yielded value that does not need to go through the region: either
+/// a value defined above the warp op with the same type as the result, or a
+/// block argument of the warp op whose matching warp operand already has the
+/// result type. Only the first forwardable result is handled per application.
+struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Value valForwarded;
+    unsigned resultIndex = 0;
+    for (OpOperand &operand : yield->getOpOperands()) {
+      Value result = warpOp.getResult(operand.getOperandNumber());
+      if (result.use_empty())
+        continue;
+
+      // Assume all the values coming from above are uniform.
+      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
+        if (result.getType() != operand.get().getType())
+          continue;
+        valForwarded = operand.get();
+        resultIndex = operand.getOperandNumber();
+        break;
+      }
+      // Otherwise only a block argument of this warp op can be forwarded, and
+      // only when the matching warp operand already has the result type.
+      auto arg = operand.get().dyn_cast<BlockArgument>();
+      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
+        continue;
+      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
+      if (result.getType() != warpOperand.getType())
+        continue;
+      valForwarded = warpOperand;
+      resultIndex = operand.getOperandNumber();
+      break;
+    }
+    if (!valForwarded)
+      return failure();
+    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
+    return success();
+  }
+};
+
+/// Sink out a vector.broadcast feeding into a warp op yield. The broadcast
+/// source is yielded instead, and the broadcast is re-created after the warp
+/// op so that it directly produces the distributed type.
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (f32)
+///   %1 = vector.broadcast %0 : f32 to vector<64xf32>
+///   vector.yield %1 : vector<64xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %w:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, f32) {
+///   %0 = "some_def"() : () -> (f32)
+///   %1 = vector.broadcast %0 : f32 to vector<64xf32>
+///   vector.yield %1, %0 : vector<64xf32>, f32
+/// }
+/// %r = vector.broadcast %w#1 : f32 to vector<2xf32>
+/// ```
+struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
+    Location loc = broadcastOp.getLoc();
+    auto destVecType =
+        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+    Value broadcastSrc = broadcastOp.getSource();
+    // Each lane rebuilds its slice by broadcasting the yielded source to the
+    // distributed type, so the source must be broadcastable to that type:
+    // aligned on trailing dimensions, every source dimension must be 1 or
+    // equal to the distributed size. E.g. a vector<32xf32> source broadcast to
+    // vector<4x32xf32> and distributed as vector<4x1xf32> cannot be sunk.
+    if (auto srcVecType = broadcastSrc.getType().dyn_cast<VectorType>()) {
+      int64_t rankDiff = destVecType.getRank() - srcVecType.getRank();
+      if (rankDiff < 0)
+        return failure();
+      for (int64_t i = 0, e = srcVecType.getRank(); i < e; ++i) {
+        int64_t srcDim = srcVecType.getDimSize(i);
+        if (srcDim != 1 && srcDim != destVecType.getDimSize(i + rankDiff))
+          return failure();
+      }
+    }
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrc.getType()});
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value broadcasted = rewriter.create<vector::BroadcastOp>(
+        loc, destVecType, newWarpOp->getResults().back());
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
+
+    return success();
+  }
+};
+
+/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
+/// the scf.ForOp is the last operation in the region so that it doesn't change
+/// the order of execution. This creates a new scf.for region after the
+/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
+/// WarpExecuteOnLane0Op region. Example:
+/// ```
+/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
+///   ...
+///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
+///   -> (vector<128xf32>) {
+///     ...
+///     scf.yield %r : vector<128xf32>
+///   }
+///   vector.yield %v1 : vector<128xf32>
+/// }
+/// ```
+/// To:
+/// ```
+/// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
+///   ...
+///   vector.yield %v : vector<128xf32>
+/// }
+/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
+///   -> (vector<4xf32>) {
+///     %iw = vector.warp_execute_on_lane_0(%laneid)
+///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
+///     ^bb0(%arg: vector<128xf32>):
+///       ...
+///       vector.yield %ir : vector<128xf32>
+///     }
+///     scf.yield %iw : vector<4xf32>
+///  }
+/// ```
+/// The rewrite only applies when every scf.for result is yielded exactly once
+/// and in result order, and when neither the loop bounds nor the loop body
+/// use values defined inside the warp op region (the new loop lives after the
+/// warp op, where those values are not visible).
+struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    // Only pick up forOp if it is the last op in the region.
+    Operation *lastNode = yield->getPrevNode();
+    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
+    if (!forOp)
+      return failure();
+
+    // Collect the yield operands produced by the forOp. The new loop pairs its
+    // i-th iter_arg with the i-th result of the original loop, so every result
+    // must be yielded exactly once and in result order.
+    SmallVector<OpOperand *> forYieldOperands;
+    for (OpOperand &yieldOperand : yield->getOpOperands()) {
+      if (yieldOperand.get().getDefiningOp() == forOp.getOperation())
+        forYieldOperands.push_back(&yieldOperand);
+    }
+    if (forYieldOperands.size() != forOp->getNumResults())
+      return failure();
+    for (const auto &it : llvm::enumerate(forYieldOperands)) {
+      if (it.value()->get().cast<OpResult>().getResultNumber() != it.index())
+        return failure();
+    }
+
+    // The loop is rebuilt after `warpOp`: its bounds, and the values used in
+    // its body other than those defined within the loop, must come from above.
+    Region &forRegion = forOp->getRegion(0);
+    auto isDefinedInWarpRegion = [&](Value value) {
+      return warpOp.getBodyRegion().isAncestor(value.getParentRegion());
+    };
+    if (isDefinedInWarpRegion(forOp.getLowerBound()) ||
+        isDefinedInWarpRegion(forOp.getUpperBound()) ||
+        isDefinedInWarpRegion(forOp.getStep()))
+      return failure();
+    WalkResult escapes = forOp->walk([&](Operation *op) -> WalkResult {
+      // The forOp's own operands (bounds, iter_args inits) are handled
+      // separately.
+      if (op == forOp.getOperation())
+        return WalkResult::advance();
+      for (Value value : op->getOperands()) {
+        if (isDefinedInWarpRegion(value) &&
+            !forRegion.isAncestor(value.getParentRegion()))
+          return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (escapes.wasInterrupted())
+      return failure();
+
+    SmallVector<Value> newOperands;
+    SmallVector<unsigned> resultIdx;
+    for (OpOperand *yieldOperand : forYieldOperands) {
+      auto forResult = yieldOperand->get().cast<OpResult>();
+      unsigned yieldIndex = yieldOperand->getOperandNumber();
+      newOperands.push_back(warpOp.getResult(yieldIndex));
+      // The warp op now yields the loop init value; the loop itself moves out.
+      yieldOperand->set(forOp.getIterOperands()[forResult.getResultNumber()]);
+      resultIdx.push_back(yieldIndex);
+    }
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(warpOp);
+    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
+    // inside.
+    auto newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+        forOp.getStep(), newOperands);
+    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
+        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
+        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
+        forOp.getResultTypes());
+
+    SmallVector<Value> argMapping;
+    argMapping.push_back(newForOp.getInductionVar());
+    for (Value args : innerWarp.getBody()->getArguments()) {
+      argMapping.push_back(args);
+    }
+    SmallVector<Value> yieldOperands;
+    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+      yieldOperands.push_back(operand);
+    rewriter.eraseOp(forOp.getBody()->getTerminator());
+    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
+    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
+    rewriter.setInsertionPointAfter(innerWarp);
+    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+    rewriter.eraseOp(forOp);
+    // Replace the warpOp result coming from the original ForOp.
+    for (const auto &res : llvm::enumerate(resultIdx)) {
+      warpOp.getResult(res.value())
+          .replaceAllUsesWith(newForOp.getResult(res.index()));
+      // replaceAllUsesWith also rewrote newForOp's own init operand; restore
+      // it. scf.for operands are (lowerBound, upperBound, step, inits...).
+      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -363,6 +735,13 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
 }
 
+/// Registers the patterns that sink distributable ops out of
+/// WarpExecuteOnLane0Op regions, plus the clean-up patterns that forward
+/// operands and drop results left unused by those rewrites.
+void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
+               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(ctx);
+}
+
 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
   Block *body = warpOp.getBody();
 

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dc4dfee861fb7..b57791ad04a19 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
 
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
@@ -126,4 +127,310 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
     vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
   }
   return
-}
\ No newline at end of file
+}
+
+// -----
+
+// Only %r#1 is used: the two unused results and their yielded values are dropped.
+// CHECK-PROP-LABEL:   func @warp_dead_result(
+func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
+  // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+  %r:3 = vector.warp_execute_on_lane_0(%laneid)[32] ->
+    (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<32xf32>)
+    %4 = "some_def"() : () -> (vector<32xf32>)
+  // CHECK-PROP:   vector.yield %{{.*}} : vector<32xf32>
+    vector.yield %2, %3, %4 : vector<32xf32>, vector<32xf32>, vector<32xf32>
+  }
+  // CHECK-PROP: return %[[R]] : vector<1xf32>
+  return %r#1 : vector<1xf32>
+}
+
+// -----
+
+// A warp block argument yielded unchanged is forwarded: the result is replaced by the warp operand %v0.
+// CHECK-PROP-LABEL:   func @warp_propagate_operand(
+//  CHECK-PROP-SAME:   %[[ID:.*]]: index, %[[V:.*]]: vector<4xf32>)
+func.func @warp_propagate_operand(%laneid: index, %v0: vector<4xf32>)
+  -> (vector<4xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32]
+     args(%v0 : vector<4xf32>) -> (vector<4xf32>) {
+     ^bb0(%arg0 : vector<128xf32>) :
+    vector.yield %arg0 : vector<128xf32>
+  }
+  // CHECK-PROP: return %[[V]] : vector<4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// Vector arith.addf ops are sunk out of the warp op; their operands become new distributed results.
+// Note: %c0, %c32 and %cst are not used by this test.
+#map0 = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-PROP-LABEL:   func @warp_propagate_elementwise(
+func.func @warp_propagate_elementwise(%laneid: index, %dest: memref<1024xf32>) {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK-PROP: %[[R:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>, vector<2xf32>, vector<2xf32>)
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] ->
+    (vector<1xf32>, vector<2xf32>) {
+    // CHECK-PROP: %[[V0:.*]] = "some_def"() : () -> vector<32xf32>
+    // CHECK-PROP: %[[V1:.*]] = "some_def"() : () -> vector<32xf32>
+    // CHECK-PROP: %[[V2:.*]] = "some_def"() : () -> vector<64xf32>
+    // CHECK-PROP: %[[V3:.*]] = "some_def"() : () -> vector<64xf32>
+    // CHECK-PROP: vector.yield %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<32xf32>, vector<32xf32>, vector<64xf32>, vector<64xf32>
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<32xf32>)
+    %4 = "some_def"() : () -> (vector<64xf32>)
+    %5 = "some_def"() : () -> (vector<64xf32>)
+    %6 = arith.addf %2, %3 : vector<32xf32>
+    %7 = arith.addf %4, %5 : vector<64xf32>
+    vector.yield %6, %7 : vector<32xf32>, vector<64xf32>
+  }
+  // CHECK-PROP: %[[A0:.*]] = arith.addf %[[R]]#2, %[[R]]#3 : vector<2xf32>
+  // CHECK-PROP: %[[A1:.*]] = arith.addf %[[R]]#0, %[[R]]#1 : vector<1xf32>
+  %id2 = affine.apply #map0()[%laneid]
+  // CHECK-PROP: vector.transfer_write %[[A1]], {{.*}} : vector<1xf32>, memref<1024xf32>
+  // CHECK-PROP: vector.transfer_write %[[A0]], {{.*}} : vector<2xf32>, memref<1024xf32>
+  vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
+  vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32>
+  return
+}
+
+// -----
+
+// A scalar arith.addf is sunk as well; its scalar operands are yielded with their own type.
+// CHECK-PROP-LABEL: func @warp_propagate_scalar_arith(
+//       CHECK-PROP:   %[[r:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} {
+//       CHECK-PROP:     %[[some_def0:.*]] = "some_def"
+//       CHECK-PROP:     %[[some_def1:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[some_def0]], %[[some_def1]]
+//       CHECK-PROP:   }
+//       CHECK-PROP:   arith.addf %[[r]]#0, %[[r]]#1 : f32
+func.func @warp_propagate_scalar_arith(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (f32)
+    %1 = "some_def"() : () -> (f32)
+    %2 = arith.addf %0, %1 : f32
+    vector.yield %2 : f32
+  }
+  vector.print %r : f32
+  return
+}
+
+// -----
+
+// The cast's operand comes from above, so after sinking the cast the test expects no warp op to remain.
+// CHECK-PROP-LABEL: func @warp_propagate_cast(
+//   CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//       CHECK-PROP:   %[[result:.*]] = arith.sitofp %{{.*}} : i32 to f32
+//       CHECK-PROP:   return %[[result]]
+func.func @warp_propagate_cast(%laneid : index, %i : i32) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %casted = arith.sitofp %i : i32 to f32
+    vector.yield %casted : f32
+  }
+  return %r : f32
+}
+
+// -----
+
+// transfer_read ops are sunk out and indexed by the lane id scaled by the distributed vector size.
+// Note: despite its name, %c32 holds 0; the expected indices (%ID and %ID * 2) rely on that.
+#map0 = affine_map<()[s0] -> (s0 * 2)>
+
+//  CHECK-PROP-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-PROP:   func @warp_propagate_read
+//  CHECK-PROP-SAME:     (%[[ID:.*]]: index
+func.func @warp_propagate_read(%laneid: index, %src: memref<1024xf32>, %dest: memref<1024xf32>) {
+// CHECK-PROP-NOT: warp_execute_on_lane_0
+// CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[ID]]], %{{.*}} : memref<1024xf32>, vector<1xf32>
+// CHECK-PROP-DAG: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
+// CHECK-PROP-DAG: %[[R1:.*]] = vector.transfer_read %arg1[%[[ID2]]], %{{.*}} : memref<1024xf32>, vector<2xf32>
+// CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1xf32>, memref<1024xf32>
+// CHECK-PROP: vector.transfer_write %[[R1]], {{.*}} : vector<2xf32>, memref<1024xf32>
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] ->(vector<1xf32>, vector<2xf32>) {
+    %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
+    %3 = vector.transfer_read %src[%c32], %cst : memref<1024xf32>, vector<64xf32>
+    vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
+  }
+  %id2 = affine.apply #map0()[%laneid]
+  vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
+  vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32>
+  return
+}
+
+// -----
+
+// Broadcasting the vector<1xf32> source to the distributed vector<1xf32> type folds away after sinking.
+// CHECK-PROP-LABEL: func @fold_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[some_def]] : vector<1xf32>
+//       CHECK-PROP:   vector.print %[[r]] : vector<1xf32>
+func.func @fold_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+    %0 = "some_def"() : () -> (vector<1xf32>)
+    %1 = vector.broadcast %0 : vector<1xf32> to vector<32xf32>
+    vector.yield %1 : vector<32xf32>
+  }
+  vector.print %r : vector<1xf32>
+  return
+}
+
+// -----
+
+// The broadcast of a vector source is re-created after the warp op, targeting the distributed type.
+// CHECK-PROP-LABEL: func @extract_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[some_def]] : vector<1xf32>
+//       CHECK-PROP:   %[[broadcasted:.*]] = vector.broadcast %[[r]] : vector<1xf32> to vector<2xf32>
+//       CHECK-PROP:   vector.print %[[broadcasted]] : vector<2xf32>
+func.func @extract_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<1xf32>)
+    %1 = vector.broadcast %0 : vector<1xf32> to vector<64xf32>
+    vector.yield %1 : vector<64xf32>
+  }
+  vector.print %r : vector<2xf32>
+  return
+}
+
+// -----
+
+// Same as above with a scalar broadcast source, which is yielded as f32.
+// CHECK-PROP-LABEL: func @extract_scalar_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (f32)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[some_def]] : f32
+//       CHECK-PROP:   %[[broadcasted:.*]] = vector.broadcast %[[r]] : f32 to vector<2xf32>
+//       CHECK-PROP:   vector.print %[[broadcasted]] : vector<2xf32>
+func.func @extract_scalar_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (f32)
+    %1 = vector.broadcast %0 : f32 to vector<64xf32>
+    vector.yield %1 : vector<64xf32>
+  }
+  vector.print %r : vector<2xf32>
+  return
+}
+
+// -----
+
+// An scf.for ending the warp region is rebuilt after the warp op around an inner warp op; its iter_args carry distributed values.
+// CHECK-PROP-LABEL:   func @warp_scf_for(
+// CHECK-PROP: %[[INI:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>) {
+// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   vector.yield %[[INI1]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]) -> (vector<4xf32>) {
+// CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]] : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:    ^bb0(%[[ARG:.*]]: vector<128xf32>):
+// CHECK-PROP:      %[[ACC:.*]] = "some_def"(%[[ARG]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:      vector.yield %[[ACC]] : vector<128xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[W]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
+func.func @warp_scf_for(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    vector.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+
+// Like @warp_scf_for with two iter_args swapped by scf.yield; the swap stays inside the inner warp op.
+// CHECK-PROP-LABEL:   func @warp_scf_for_swap(
+// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   vector.yield %[[INI1]], %[[INI2]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG1:.*]] = %[[INI]]#0, %[[FARG2:.*]] = %[[INI]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG1]], %[[FARG2]] : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:    ^bb0(%[[ARG1:.*]]: vector<128xf32>, %[[ARG2:.*]]: vector<128xf32>):
+// CHECK-PROP:      %[[ACC1:.*]] = "some_def"(%[[ARG1]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:      %[[ACC2:.*]] = "some_def"(%[[ARG2]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:      vector.yield %[[ACC2]], %[[ACC1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[W]]#0, %[[W]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+// CHECK-PROP: "some_use"(%[[F]]#1) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_swap(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:2 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>) {
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2) -> (vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc2 = "some_def"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc2, %acc1 : vector<128xf32>, vector<128xf32>
+    }
+    vector.yield %3#0, %3#1 : vector<128xf32>, vector<128xf32>
+  }
+  "some_use"(%0#0) : (vector<4xf32>) -> ()
+  "some_use"(%0#1) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+
+// Combined propagation: the loop is sunk, then the reads and addf ops in and before it are distributed,
+// leaving only the warp op that yields the unrelated "some_def" value.
+#map = affine_map<()[s0] -> (s0 * 4)>
+#map1 = affine_map<()[s0] -> (s0 * 128 + 128)>
+#map2 = affine_map<()[s0] -> (s0 * 4 + 128)>
+
+// CHECK-PROP-LABEL:   func @warp_scf_for_multiple_yield(
+//       CHECK-PROP:   vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+//  CHECK-PROP-NEXT:     "some_def"() : () -> vector<32xf32>
+//  CHECK-PROP-NEXT:     vector.yield %{{.*}} : vector<32xf32>
+//  CHECK-PROP-NEXT:   }
+//   CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//       CHECK-PROP:   vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+//       CHECK-PROP:   vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+//       CHECK-PROP:   %{{.*}}:2 = scf.for {{.*}} -> (vector<4xf32>, vector<4xf32>) {
+//   CHECK-PROP-NOT:     vector.warp_execute_on_lane_0
+//       CHECK-PROP:     vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+//       CHECK-PROP:     vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+//       CHECK-PROP:     arith.addf {{.*}} : vector<4xf32>
+//       CHECK-PROP:     arith.addf {{.*}} : vector<4xf32>
+//       CHECK-PROP:     scf.yield {{.*}} : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP:   }
+func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+  %c256 = arith.constant 256 : index
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0:3 = vector.warp_execute_on_lane_0(%arg0)[32] ->
+  (vector<1xf32>, vector<4xf32>, vector<4xf32>) {
+    %def = "some_def"() : () -> (vector<32xf32>)
+    %r1 = vector.transfer_read %arg2[%c0], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+    %r2 = vector.transfer_read %arg2[%c128], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %r1, %arg5 = %r2)
+    -> (vector<128xf32>, vector<128xf32>) {
+      %o1 = affine.apply #map1()[%arg3]
+      %o2 = affine.apply #map2()[%arg3]
+      %4 = vector.transfer_read %arg1[%o1], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+      %5 = vector.transfer_read %arg1[%o2], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+      %6 = arith.addf %4, %arg4 : vector<128xf32>
+      %7 = arith.addf %5, %arg5 : vector<128xf32>
+      scf.yield %6, %7 : vector<128xf32>, vector<128xf32>
+    }
+    vector.yield %def, %3#0, %3#1 :  vector<32xf32>, vector<128xf32>, vector<128xf32>
+  }
+  %1 = affine.apply #map()[%arg0]
+  vector.transfer_write %0#1, %arg2[%1] {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
+  %2 = affine.apply #map2()[%arg0]
+  vector.transfer_write %0#2, %arg2[%2] {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
+  "some_use"(%0#0) : (vector<1xf32>) -> ()
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
index 2205079f246a4..159c677b96319 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
@@ -11,6 +11,31 @@
 // RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+// Run the same test cases again: first with distribution only, then with
+// distribution followed by propagation (propagate-distribution).
+// RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \
+// RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
+// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN:  -gpu-kernel-outlining \
+// RUN:  -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN:  -gpu-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
+// RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
+// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN:  -gpu-kernel-outlining \
+// RUN:  -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN:  -gpu-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
 func.func @gpu_func(%arg1: memref<32xf32>, %arg2: memref<32xf32>) {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e1ffddc5f0687..cfdeb1e632e0b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -835,6 +835,10 @@ struct TestVectorDistribution
                             llvm::cl::desc("Test hoist uniform"),
                             llvm::cl::init(false)};
 
+  Option<bool> propagateDistribution{
+      *this, "propagate-distribution",
+      llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
 
@@ -862,7 +866,11 @@ struct TestVectorDistribution
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }
-
+    if (propagateDistribution) {
+      RewritePatternSet patterns(ctx);
+      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    }
     WarpExecuteOnLane0LoweringOptions options;
     options.warpAllocationFn = allocateGlobalSharedMemory;
     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,


        


More information about the Mlir-commits mailing list