[Mlir-commits] [mlir] fa8a10a - [mlir][Vector] Refactor vector distribution and fix an issue related to non-homogenous transfer indices.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 2 02:18:36 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-02T02:18:26-07:00
New Revision: fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d

URL: https://github.com/llvm/llvm-project/commit/fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d
DIFF: https://github.com/llvm/llvm-project/commit/fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d.diff

LOG: [mlir][Vector] Refactor vector distribution and fix an issue related to non-homogenous transfer indices.

Running: `mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize -verify-each=0`.

Prior to this revision, IR resembling the following would be produced:
```
  %4 = "vector.load"(%3, %arg0) : (memref<1x32xf32, 3>, index) -> vector<1x1xf32>
```
This fails verification: loading from the rank-2 memref `memref<1x32xf32, 3>` requires 2 indices, but only 1 is provided.

Differential Revision: https://reviews.llvm.org/D133106

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8ecb8986fd9f..22cabaf979e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -10,9 +10,9 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include <utility>
@@ -20,122 +20,115 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-static LogicalResult
-rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-                      const WarpExecuteOnLane0LoweringOptions &options) {
-  assert(warpOp.getBodyRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-  Block *warpOpBody = &warpOp.getBodyRegion().front();
-  Location loc = warpOp.getLoc();
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
+static llvm::Optional<int64_t>
+getDistributedVectorDim(VectorType distributedVectorType) {
+  return (distributedVectorType)
+             ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
+             : llvm::None;
+}
 
-  // Passed all checks. Start rewriting.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
+static llvm::Optional<int64_t>
+getDistributedSize(VectorType distributedVectorType) {
+  auto dim = getDistributedVectorDim(distributedVectorType);
+  return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
+               : llvm::None;
+}
 
-  // Create scf.if op.
-  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                                 warpOp.getLaneid(), c0);
-  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
-                                         /*withElseRegion=*/false);
-  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
-
-  // Store vectors that are defined outside of warpOp into the scratch pad
-  // buffer.
-  SmallVector<Value> bbArgReplacements;
-  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
-    Value val = it.value();
-    Value bbArg = warpOpBody->getArgument(it.index());
-
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
-
-    // Store arg vector into buffer.
-    rewriter.setInsertionPoint(ifOp);
-    auto vectorType = val.getType().cast<VectorType>();
-    int64_t storeSize = vectorType.getShape()[0];
-    Value storeOffset = rewriter.create<arith::MulIOp>(
-        loc, warpOp.getLaneid(),
-        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
-    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
-
-    // Load bbArg vector from buffer.
-    rewriter.setInsertionPointToStart(ifOp.thenBlock());
-    auto bbArgType = bbArg.getType().cast<VectorType>();
-    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
-    bbArgReplacements.push_back(loadOp);
+namespace {
+
+/// Helper struct to create the load / store operations that permit transit
+/// through the parallel / sequential and the sequential / parallel boundaries
+/// when performing `rewriteWarpOpToScfFor`.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size ?
+/// TODO: add an analysis step that determines which vector dimension should
+/// be used for distribution.
+struct DistributedLoadStoreHelper {
+  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
+                             Value laneId, Value zero)
+      : sequentialVal(sequentialVal), distributedVal(distributedVal),
+        laneId(laneId), zero(zero) {
+    sequentialType = sequentialVal.getType();
+    distributedType = distributedVal.getType();
+    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
+    distributedVectorType = distributedType.dyn_cast<VectorType>();
   }
 
-  // Insert sync after all the stores and before all the loads.
-  if (!warpOp.getArgs().empty()) {
-    rewriter.setInsertionPoint(ifOp);
-    options.warpSyncronizationFn(loc, rewriter, warpOp);
+  Value buildDistributedOffset(RewriterBase &b, Location loc) {
+    auto maybeDistributedSize = getDistributedSize(distributedVectorType);
+    assert(maybeDistributedSize &&
+           "at this point, a distributed size must be determined");
+    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
+    return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+                                         ArrayRef<Value>{laneId});
   }
 
-  // Move body of warpOp to ifOp.
-  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
-
-  // Rewrite terminator and compute replacements of WarpOp results.
-  SmallVector<Value> replacements;
-  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
-  Location yieldLoc = yieldOp.getLoc();
-  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-    Value val = it.value();
-    Type resultType = warpOp->getResultTypes()[it.index()];
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
-
-    // Store yielded value into buffer.
-    rewriter.setInsertionPoint(yieldOp);
-    if (val.getType().isa<VectorType>())
-      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
-    else
-      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
-
-    // Load value from buffer (after warpOp).
-    rewriter.setInsertionPointAfter(ifOp);
-    if (resultType == val.getType()) {
-      // Result type and yielded value type are the same. This is a broadcast.
-      // E.g.:
-      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector.yield %cst : f32
-      // }
-      // Both types are f32. The constant %cst is broadcasted to all lanes.
-      // This is described in more detail in the documentation of the op.
-      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
-      replacements.push_back(loadOp);
-    } else {
-      auto loadedVectorType = resultType.cast<VectorType>();
-      int64_t loadSize = loadedVectorType.getShape()[0];
-
-      // loadOffset = laneid * loadSize
-      Value loadOffset = rewriter.create<arith::MulIOp>(
-          loc, warpOp.getLaneid(),
-          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
-      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
-                                                     buffer, loadOffset);
-      replacements.push_back(loadOp);
+  Operation *buildStore(RewriterBase &b, Location loc, Value val,
+                        Value buffer) {
+    assert((val == distributedVal || val == sequentialVal) &&
+           "Must store either the preregistered distributed or the "
+           "preregistered sequential value.");
+    // Vector case must use vector::TransferWriteOp which will later lower to
+    //   vector.store of memref.store depending on further lowerings.
+    if (val.getType().isa<VectorType>()) {
+      SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+      auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+      assert(maybeDistributedDim && "must be able to deduce distributed dim");
+      if (val == distributedVal)
+        indices[*maybeDistributedDim] =
+            (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
+      SmallVector<bool> inBounds(indices.size(), true);
+      return b.create<vector::TransferWriteOp>(
+          loc, val, buffer, indices,
+          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
     }
+    // Scalar case can directly use memref.store.
+    return b.create<memref::StoreOp>(loc, val, buffer, zero);
   }
 
-  // Insert sync after all the stores and before all the loads.
-  if (!yieldOp.operands().empty()) {
-    rewriter.setInsertionPointAfter(ifOp);
-    options.warpSyncronizationFn(loc, rewriter, warpOp);
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
+                  bool broadcastMode = false) {
+    // When broadcastMode is true, this is a broadcast.
+    // E.g.:
+    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+    //   vector.yield %cst : f32
+    // }
+    // Both types are f32. The constant %cst is broadcasted to all lanes.
+    // This is described in more detail in the documentation of the op.
+    if (broadcastMode)
+      return b.create<memref::LoadOp>(loc, buffer, zero);
+
+    // Other cases must be vector atm.
+    // Vector case must use vector::TransferReadOp which will later lower to
+    //   vector.read of memref.read depending on further lowerings.
+    assert(type.isa<VectorType>() && "must be a vector type");
+    assert((type == distributedVectorType || type == sequentialVectorType) &&
+           "Must store either the preregistered distributed or the "
+           "preregistered sequential type.");
+    auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+    assert(maybeDistributedDim && "must be able to deduce distributed dim");
+    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+    if (type == distributedVectorType) {
+      indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
+    } else {
+      indices[*maybeDistributedDim] = zero;
+    }
+    SmallVector<bool> inBounds(indices.size(), true);
+    return b.create<vector::TransferReadOp>(
+        loc, type.cast<VectorType>(), buffer, indices,
+        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
   }
 
-  // Delete terminator and add empty scf.yield.
-  rewriter.eraseOp(yieldOp);
-  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
-  rewriter.create<scf::YieldOp>(yieldLoc);
-
-  // Compute replacements for WarpOp results.
-  rewriter.replaceOp(warpOp, replacements);
+  Value sequentialVal, distributedVal, laneId, zero;
+  Type sequentialType, distributedType;
+  VectorType sequentialVectorType, distributedVectorType;
+};
 
-  return success();
-}
+} // namespace
 
 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
@@ -261,6 +254,37 @@ static AffineMap calculateImplicitMap(Value yield, Value ret) {
 
 namespace {
 
+/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
+/// thread `laneId` executes the entirety of the computation.
+///
+/// After the transformation:
+///   - the IR within the scf.if op can be thought of as executing sequentially
+///     (from the point of view of threads along `laneId`).
+///   - the IR outside of the scf.if op can be thought of as executing in
+///     parallel (from the point of view of threads along `laneId`).
+///
+/// Values that need to transit through the parallel / sequential and the
+/// sequential / parallel boundaries do so via reads and writes to a temporary
+/// memory location.
+///
+/// The transformation proceeds in multiple steps:
+///   1. Create the scf.if op.
+///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
+///      within the scf.if to transit the values captured from above.
+///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
+///      consistent within the scf.if.
+///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
+///   5. Insert appropriate writes within scf.if and reads after the scf.if to
+///      transit the values returned by the op.
+///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
+///      consistent after the scf.if.
+///   7. Perform late cleanups.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size ?
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpToScfForPattern(MLIRContext *context,
                         const WarpExecuteOnLane0LoweringOptions &options,
@@ -270,7 +294,106 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
+    assert(warpOp.getBodyRegion().hasOneBlock() &&
+           "expected WarpOp with single block");
+    Block *warpOpBody = &warpOp.getBodyRegion().front();
+    Location loc = warpOp.getLoc();
+
+    // Passed all checks. Start rewriting.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(warpOp);
+
+    // Step 1: Create scf.if op.
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value isLane0 = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
+    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
+                                           /*withElseRegion=*/false);
+    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
+
+    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
+    // reads within the scf.if to transit the values captured from above.
+    SmallVector<Value> bbArgReplacements;
+    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
+      Value sequentialVal = warpOpBody->getArgument(it.index());
+      Value distributedVal = it.value();
+      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                        warpOp.getLaneid(), c0);
+
+      // Create buffer before the ifOp.
+      rewriter.setInsertionPoint(ifOp);
+      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                              sequentialVal.getType());
+      // Store distributed vector into buffer, before the ifOp.
+      helper.buildStore(rewriter, loc, distributedVal, buffer);
+      // Load sequential vector from buffer, inside the ifOp.
+      rewriter.setInsertionPointToStart(ifOp.thenBlock());
+      bbArgReplacements.push_back(
+          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
+    }
+
+    // Step 3. Insert sync after all the stores and before all the loads.
+    if (!warpOp.getArgs().empty()) {
+      rewriter.setInsertionPoint(ifOp);
+      options.warpSyncronizationFn(loc, rewriter, warpOp);
+    }
+
+    // Step 4. Move body of warpOp to ifOp.
+    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
+
+    // Step 5. Insert appropriate writes within scf.if and reads after the
+    // scf.if to transit the values returned by the op.
+    // TODO: at this point, we can reuse the shared memory from previous
+    // buffers.
+    SmallVector<Value> replacements;
+    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
+    Location yieldLoc = yieldOp.getLoc();
+    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+      Value sequentialVal = it.value();
+      Value distributedVal = warpOp->getResult(it.index());
+      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                        warpOp.getLaneid(), c0);
+
+      // Create buffer before the ifOp.
+      rewriter.setInsertionPoint(ifOp);
+      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                              sequentialVal.getType());
+
+      // Store yielded value into buffer, inside the ifOp, before the
+      // terminator.
+      rewriter.setInsertionPoint(yieldOp);
+      helper.buildStore(rewriter, loc, sequentialVal, buffer);
+
+      // Load distributed value from buffer, after  the warpOp.
+      rewriter.setInsertionPointAfter(ifOp);
+      bool broadcastMode =
+          (sequentialVal.getType() == distributedVal.getType());
+      // Result type and yielded value type are the same. This is a broadcast.
+      // E.g.:
+      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector.yield %cst : f32
+      // }
+      // Both types are f32. The constant %cst is broadcasted to all lanes.
+      // This is described in more detail in the documentation of the op.
+      replacements.push_back(helper.buildLoad(
+          rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+    }
+
+    // Step 6. Insert sync after all the stores and before all the loads.
+    if (!yieldOp.operands().empty()) {
+      rewriter.setInsertionPointAfter(ifOp);
+      options.warpSyncronizationFn(loc, rewriter, warpOp);
+    }
+
+    // Step 7. Delete terminator and add empty scf.yield.
+    rewriter.eraseOp(yieldOp);
+    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
+    rewriter.create<scf::YieldOp>(yieldLoc);
+
+    // Compute replacements for WarpOp results.
+    rewriter.replaceOp(warpOp, replacements);
+
+    return success();
   }
 
 private:

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index cceb04b4d52f..8f7e867b5b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -4,7 +4,9 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
 
-
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-SCF-IF-DAG: #[[$TIMES8:.*]] = affine_map<()[s0] -> (s0 * 8)>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_128xf32 : memref<128xf32, 3>
@@ -16,17 +18,14 @@
 func.func @rewrite_warp_op_to_scf_if(%laneid: index,
                                 %v0: vector<4xf32>, %v1: vector<8xf32>) {
 //   CHECK-SCF-IF-DAG:   %[[c0:.*]] = arith.constant 0 : index
-//   CHECK-SCF-IF-DAG:   %[[c2:.*]] = arith.constant 2 : index
-//   CHECK-SCF-IF-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-SCF-IF-DAG:   %[[c8:.*]] = arith.constant 8 : index
 //       CHECK-SCF-IF:   %[[is_lane_0:.*]] = arith.cmpi eq, %[[laneid]], %[[c0]]
 
 //       CHECK-SCF-IF:   %[[buffer_v0:.*]] = memref.get_global @__shared_128xf32
-//       CHECK-SCF-IF:   %[[s0:.*]] = arith.muli %[[laneid]], %[[c4]]
-//       CHECK-SCF-IF:   vector.store %[[v0]], %[[buffer_v0]][%[[s0]]]
+//       CHECK-SCF-IF:   %[[s0:.*]] = affine.apply #[[$TIMES4]]()[%[[laneid]]]
+//       CHECK-SCF-IF:   vector.transfer_write %[[v0]], %[[buffer_v0]][%[[s0]]]
 //       CHECK-SCF-IF:   %[[buffer_v1:.*]] = memref.get_global @__shared_256xf32
-//       CHECK-SCF-IF:   %[[s1:.*]] = arith.muli %[[laneid]], %[[c8]]
-//       CHECK-SCF-IF:   vector.store %[[v1]], %[[buffer_v1]][%[[s1]]]
+//       CHECK-SCF-IF:   %[[s1:.*]] = affine.apply #[[$TIMES8]]()[%[[laneid]]]
+//       CHECK-SCF-IF:   vector.transfer_write %[[v1]], %[[buffer_v1]][%[[s1]]]
 
 //   CHECK-SCF-IF-DAG:   gpu.barrier
 //   CHECK-SCF-IF-DAG:   %[[buffer_def_0:.*]] = memref.get_global @__shared_32xf32
@@ -36,21 +35,21 @@ func.func @rewrite_warp_op_to_scf_if(%laneid: index,
   %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
       args(%v0, %v1 : vector<4xf32>, vector<8xf32>) -> (vector<1xf32>, vector<2xf32>) {
     ^bb0(%arg0: vector<128xf32>, %arg1: vector<256xf32>):
-//       CHECK-SCF-IF:     %[[arg1:.*]] = vector.load %[[buffer_v1]][%[[c0]]] : memref<256xf32, 3>, vector<256xf32>
-//       CHECK-SCF-IF:     %[[arg0:.*]] = vector.load %[[buffer_v0]][%[[c0]]] : memref<128xf32, 3>, vector<128xf32>
+//       CHECK-SCF-IF:     %[[arg1:.*]] = vector.transfer_read %[[buffer_v1]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<256xf32, 3>, vector<256xf32>
+//       CHECK-SCF-IF:     %[[arg0:.*]] = vector.transfer_read %[[buffer_v0]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<128xf32, 3>, vector<128xf32>
 //       CHECK-SCF-IF:     %[[def_0:.*]] = "some_def"(%[[arg0]]) : (vector<128xf32>) -> vector<32xf32>
 //       CHECK-SCF-IF:     %[[def_1:.*]] = "some_def"(%[[arg1]]) : (vector<256xf32>) -> vector<64xf32>
     %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
     %3 = "some_def"(%arg1) : (vector<256xf32>) -> vector<64xf32>
-//       CHECK-SCF-IF:     vector.store %[[def_0]], %[[buffer_def_0]][%[[c0]]]
-//       CHECK-SCF-IF:     vector.store %[[def_1]], %[[buffer_def_1]][%[[c0]]]
+//       CHECK-SCF-IF:     vector.transfer_write %[[def_0]], %[[buffer_def_0]][%[[c0]]]
+//       CHECK-SCF-IF:     vector.transfer_write %[[def_1]], %[[buffer_def_1]][%[[c0]]]
     vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
   }
 //       CHECK-SCF-IF:   }
 //       CHECK-SCF-IF:   gpu.barrier
-//       CHECK-SCF-IF:   %[[o1:.*]] = arith.muli %[[laneid]], %[[c2]]
-//       CHECK-SCF-IF:   %[[r1:.*]] = vector.load %[[buffer_def_1]][%[[o1]]] : memref<64xf32, 3>, vector<2xf32>
-//       CHECK-SCF-IF:   %[[r0:.*]] = vector.load %[[buffer_def_0]][%[[laneid]]] : memref<32xf32, 3>, vector<1xf32>
+//       CHECK-SCF-IF:   %[[o1:.*]] = affine.apply #[[$TIMES2]]()[%[[laneid]]]
+//       CHECK-SCF-IF:   %[[r1:.*]] = vector.transfer_read %[[buffer_def_1]][%[[o1]]], %{{.*}} {in_bounds = [true]} : memref<64xf32, 3>, vector<2xf32>
+//       CHECK-SCF-IF:   %[[r0:.*]] = vector.transfer_read %[[buffer_def_0]][%[[laneid]]], %{{.*}} {in_bounds = [true]} : memref<32xf32, 3>, vector<1xf32>
 //       CHECK-SCF-IF:   "some_use"(%[[r0]]) : (vector<1xf32>) -> ()
 //       CHECK-SCF-IF:   "some_use"(%[[r1]]) : (vector<2xf32>) -> ()
   "some_use"(%r#0) : (vector<1xf32>) -> ()
@@ -631,3 +630,23 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
   }
   return %r : f32
 }
+
+// -----
+
+// CHECK-PROP:   func @lane_dependent_warp_propagate_read
+//  CHECK-PROP-SAME:   %[[ID:.*]]: index
+func.func @lane_dependent_warp_propagate_read(
+    %laneid: index, %src: memref<1x1024xf32>, %dest: memref<1x1024xf32>) {
+  // CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-PROP-NOT: lane_dependent_warp_propagate_read
+  // CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[C0]], %[[ID]]], %{{.*}} : memref<1x1024xf32>, vector<1x1xf32>
+  // CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1x1xf32>, memref<1x1024xf32>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<1x1024xf32>, vector<1x32xf32>
+    vector.yield %2 : vector<1x32xf32>
+  }
+  vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
+  return
+}


        


More information about the Mlir-commits mailing list