[Mlir-commits] [mlir] ed0288f - [mlir][vector] Add patterns for vector distribution

Thomas Raoux llvmlistbot at llvm.org
Fri Jun 10 10:47:00 PDT 2022


Author: Thomas Raoux
Date: 2022-06-10T17:46:51Z
New Revision: ed0288f7c4a5b3bf486803a11a71cbff35aa0111

URL: https://github.com/llvm/llvm-project/commit/ed0288f7c4a5b3bf486803a11a71cbff35aa0111
DIFF: https://github.com/llvm/llvm-project/commit/ed0288f7c4a5b3bf486803a11a71cbff35aa0111.diff

LOG: [mlir][vector] Add patterns for vector distribution

Add a pattern to hoist scalar code out of the warp distribute region:
such code cannot be distributed, and we want to execute it uniformly on
all the lanes.
Also add patterns to distribute transfer_write ops. These ops can be
distributed in different ways, and the strategy is controlled by the
user through a callback.
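
For illustration, to distribute along the innermost vector dimension a
user can supply a callback like the one used by the test pass in this
commit. A minimal sketch (the greedy pattern-application boilerplate
around the callback is shown only for context):

    // Build the map (d0, ..., dn) -> (dn): each lane owns a contiguous
    // slice of the innermost vector dimension.
    auto distributionFn = [](vector::TransferWriteOp writeOp) {
      int64_t rank = writeOp.getVectorType().getRank();
      return AffineMap::get(rank, /*symbolCount=*/0,
                            getAffineDimExpr(rank - 1, writeOp.getContext()));
    };
    RewritePatternSet patterns(ctx);
    vector::populateDistributeTransferWriteOpPatterns(patterns,
                                                      distributionFn);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));

With a warp size of 32, a transfer_write of vector<64xf32> at index %c32
then becomes a per-lane write of vector<2xf32> at index 32 + 2 * laneid,
i.e. the affine map ()[s0] -> (s0 * 2 + 32) checked in the test.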

Differential Revision: https://reviews.llvm.org/D127152

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 06ca02483be54..b95b527d0639c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -39,6 +39,32 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options);
 
+using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`.
+/// Example:
+/// ```
+/// vector.warp_execute_on_lane_0(%id)[32] {
+///   ...
+///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
+///   ...
+///   vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
+/// ```
+void populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn);
+
+/// Move scalar operations with no dependency on the warp op outside of the
+/// region.
+void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 933d572023594..586604f6fd6c3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Transforms/SideEffectUtils.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -93,8 +95,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     if (resultType == val.getType()) {
       // Result type and yielded value type are the same. This is a broadcast.
       // E.g.:
-      // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector_ext.yield %cst : f32
+      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector.yield %cst : f32
       // }
       // Both types are f32. The constant %cst is broadcasted to all lanes.
       // This is described in more detail in the documentation of the op.
@@ -131,6 +133,54 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   return success();
 }
 
+/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.updateRootInPlace(
+      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  types.append(newReturnTypes.begin(), newReturnTypes.end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  SmallVector<Value> yieldValues(yield.getOperands().begin(),
+                                 yield.getOperands().end());
+  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues, types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+/// Helper to know if an op can be hoisted out of the region.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return llvm::all_of(op->getOperands(), definedOutside) &&
+         isSideEffectFree(op) && op->getNumRegions() == 0;
+}
+
 namespace {
 
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -149,6 +199,157 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   const WarpExecuteOnLane0LoweringOptions &options;
 };
 
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`.
+/// Example:
+/// ```
+/// vector.warp_execute_on_lane_0(%id)[32] {
+///   ...
+///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
+///   ...
+///   vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
+/// ```
+struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
+                      PatternBenefit b = 1)
+      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
+        distributionMapFn(fn) {}
+
+  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
+  /// are multiples of the distribution ratio are supported at the moment.
+  LogicalResult tryDistributeOp(RewriterBase &rewriter,
+                                vector::TransferWriteOp writeOp,
+                                WarpExecuteOnLane0Op warpOp) const {
+    AffineMap map = distributionMapFn(writeOp);
+    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
+                                     writeOp.getVectorType().getShape().end());
+    assert(map.getNumResults() == 1 &&
+           "multi-dim distribution not implemented yet");
+    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+      unsigned position = map.getDimPosition(i);
+      if (targetShape[position] % warpOp.getWarpSize() != 0)
+        return failure();
+      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
+    }
+    VectorType targetType =
+        VectorType::get(targetShape, writeOp.getVectorType().getElementType());
+
+    SmallVector<Value> yieldValues = {writeOp.getVector()};
+    SmallVector<Type> retTypes = {targetType};
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, yieldValues, retTypes);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    // Move op outside of region: Insert clone at the insertion point and delete
+    // the old op.
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    rewriter.eraseOp(writeOp);
+
+    rewriter.setInsertionPoint(newWriteOp);
+    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
+    Location loc = newWriteOp.getLoc();
+    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
+                               newWriteOp.getIndices().end());
+    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+      AffineExpr d0, d1;
+      bindDims(newWarpOp.getContext(), d0, d1);
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+      auto scale =
+          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
+      indices[indexPos] =
+          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
+                                  {indices[indexPos], newWarpOp.getLaneid()});
+    }
+    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+    newWriteOp.getIndicesMutable().assign(indices);
+
+    return success();
+  }
+
+  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
+  LogicalResult tryExtractOp(RewriterBase &rewriter,
+                             vector::TransferWriteOp writeOp,
+                             WarpExecuteOnLane0Op warpOp) const {
+    Location loc = writeOp.getLoc();
+    VectorType vecType = writeOp.getVectorType();
+
+    // Only vector<1x> is supported at the moment.
+    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
+      return failure();
+
+    // Do not process warp ops that contain only TransferWriteOps.
+    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
+          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
+        }))
+      return failure();
+
+    SmallVector<Value> yieldValues = {writeOp.getVector()};
+    SmallVector<Type> retTypes = {vecType};
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, yieldValues, retTypes);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    // Create a second warp op that contains only writeOp.
+    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
+    Block &body = secondWarpOp.getBodyRegion().front();
+    rewriter.setInsertionPointToStart(&body);
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    newWriteOp.getVectorMutable().assign(
+        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
+    rewriter.eraseOp(writeOp);
+    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
+    return success();
+  }
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // Ops with mask not supported yet.
+    if (writeOp.getMask())
+      return failure();
+
+    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
+    if (!warpOp)
+      return failure();
+
+    // There must be no op with a side effect after writeOp.
+    Operation *nextOp = writeOp.getOperation();
+    while ((nextOp = nextOp->getNextNode()))
+      if (!isSideEffectFree(nextOp))
+        return failure();
+
+    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
+          return writeOp.getVector() == value ||
+                 warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
+    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
+      return success();
+
+    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
+      return success();
+
+    return failure();
+  }
+
+private:
+  DistributionMapFn distributionMapFn;
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -156,3 +357,36 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     const WarpExecuteOnLane0LoweringOptions &options) {
   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
 }
+
+void mlir::vector::populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
+  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
+}
+
+void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
+  Block *body = warpOp.getBody();
+
+  // Keep track of the ops we want to hoist.
+  llvm::SmallSetVector<Operation *, 8> opsToMove;
+
+  // Helper to check if a value is or will be defined outside of the region.
+  auto isDefinedOutsideOfBody = [&](Value value) {
+    auto *definingOp = value.getDefiningOp();
+    return (definingOp && opsToMove.count(definingOp)) ||
+           warpOp.isDefinedOutsideOfRegion(value);
+  };
+
+  // Do not use walk here, as we do not want to go into nested regions and hoist
+  // operations from there.
+  for (auto &op : body->without_terminator()) {
+    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
+      return result.getType().isa<VectorType>();
+    });
+    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
+      opsToMove.insert(&op);
+  }
+
+  // Move all the ops marked as uniform outside of the region.
+  for (Operation *op : opsToMove)
+    op->moveBefore(warpOp);
+}

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index cba8f051ceba9..dc4dfee861fb7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,4 +1,6 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
 
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
@@ -52,3 +54,76 @@ func.func @rewrite_warp_op_to_scf_if(%laneid: index,
   "some_use"(%r#1) : (vector<2xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK-D-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 2 + 32)>
+
+// CHECK-ALL-LABEL: func @warp(
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: vector.warp_execute_on_lane_0
+
+//     CHECK-D: %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>, vector<1xf32>) {
+//     CHECK-D:   arith.addf {{.*}} : vector<32xf32>
+//     CHECK-D:   arith.addf {{.*}} : vector<64xf32>
+//     CHECK-D:   vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, vector<32xf32>
+// CHECK-D-DAG: vector.transfer_write %[[R]]#1, %{{.*}}[%{{.*}}] {in_bounds = [true]} : vector<1xf32>, memref<128xf32
+// CHECK-D-DAG: %[[ID1:.*]] = affine.apply #[[MAP1]]()[%{{.*}}]
+// CHECK-D-DAG: vector.transfer_write %[[R]]#0, %2[%[[ID1]]] {in_bounds = [true]} : vector<2xf32>, memref<128xf32
+
+// CHECK-ALL-NOT: vector.warp_execute_on_lane_0
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<1xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<2xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<1xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<2xf32>
+
+#map0 =  affine_map<(d0)[s0] -> (d0 + s0)>
+func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>,
+           %arg3: memref<1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %sb = memref.subview %arg2[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %sc = memref.subview %arg3[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %2 = vector.transfer_read %sa[%c0], %cst : memref<128xf32, #map0>, vector<32xf32>
+    %3 = vector.transfer_read %sa[%c32], %cst : memref<128xf32, #map0>, vector<32xf32>
+    %4 = vector.transfer_read %sb[%c0], %cst : memref<128xf32, #map0>, vector<64xf32>
+    %5 = vector.transfer_read %sb[%c32], %cst : memref<128xf32, #map0>, vector<64xf32>
+    %6 = arith.addf %2, %3 : vector<32xf32>
+    %7 = arith.addf %4, %5 : vector<64xf32>
+    vector.transfer_write %6, %sc[%c0] : vector<32xf32>, memref<128xf32, #map0>
+    vector.transfer_write %7, %sc[%c32] : vector<64xf32>, memref<128xf32, #map0>
+  }
+  return
+}
+
+// -----
+
+// CHECK-D-LABEL: func @warp_extract(
+//       CHECK-D:   %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     vector.yield %{{.*}} : vector<1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
+//       CHECK-D:   }
+
+#map2 =  affine_map<(d0)[s0] -> (d0 + s0)>
+
+func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<1xf32>)
+    vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
+  }
+  return
+}
\ No newline at end of file

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f216ba8e80eb9..e1ffddc5f0687 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -809,7 +809,8 @@ struct TestVectorDistribution
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
+    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
+                    AffineDialect>();
   }
 
   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
@@ -825,8 +826,43 @@ struct TestVectorDistribution
       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
       llvm::cl::init(false)};
 
+  Option<bool> distributeTransferWriteOps{
+      *this, "distribute-transfer-write",
+      llvm::cl::desc("Test distribution of transfer write"),
+      llvm::cl::init(false)};
+
+  Option<bool> hoistUniform{*this, "hoist-uniform",
+                            llvm::cl::desc("Test hoist uniform"),
+                            llvm::cl::init(false)};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
+
+    // Hoist uniform scalar code out of each warp op before distribution.
+    getOperation().walk([&](WarpExecuteOnLane0Op warpOp) {
+      if (hoistUniform)
+        moveScalarUniformCode(warpOp);
+    });
+    MLIRContext *ctx = &getContext();
+    if (distributeTransferWriteOps) {
+      auto distributionFn = [](vector::TransferWriteOp writeOp) {
+        // Create a map (d0, d1) -> (d1) to distribute along the inner
+        // dimension. Once we support n-d distribution we can add more
+        // complex cases.
+        int64_t vecRank = writeOp.getVectorType().getRank();
+        OpBuilder builder(writeOp.getContext());
+        auto map =
+            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+        return map;
+      };
+      RewritePatternSet patterns(ctx);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    }
+
     WarpExecuteOnLane0LoweringOptions options;
     options.warpAllocationFn = allocateGlobalSharedMemory;
     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,


        

