[Mlir-commits] [mlir] fa8a10a - [mlir][Vector] Refactor vector distribution and fix an issue related to non-homogenous transfer indices.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 2 02:18:36 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-02T02:18:26-07:00
New Revision: fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d
URL: https://github.com/llvm/llvm-project/commit/fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d
DIFF: https://github.com/llvm/llvm-project/commit/fa8a10a1fd0dfd6e54822a033fb8b0900bd19c6d.diff
LOG: [mlir][Vector] Refactor vector distribution and fix an issue related to non-homogenous transfer indices.
Running `mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize -verify-each=0` prior to this revision would produce IR resembling the following:
```
%4 = "vector.load"(%3, %arg0) : (memref<1x32xf32, 3>, index) -> vector<1x1xf32>
```
This fails verification since a load from a 2-D memref needs 2 indices, but only 1 is provided.
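With this revision, the rewrite goes through `vector.transfer_write` / `vector.transfer_read` and materializes one index per memref dimension, so the equivalent access verifies. A minimal sketch of the expected shape of the IR for the same 2-D buffer (SSA names such as `%c0`, `%laneid`, and `%cst` are illustrative; the distributed offset is produced by an `affine.apply` on the lane id):
```
%4 = vector.transfer_read %3[%c0, %laneid], %cst {in_bounds = [true, true]}
    : memref<1x32xf32, 3>, vector<1x1xf32>
```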
Differential Revision: https://reviews.llvm.org/D133106
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8ecb8986fd9f..22cabaf979e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -10,9 +10,9 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/Transforms/SideEffectUtils.h"
#include "llvm/ADT/SetVector.h"
#include <utility>
@@ -20,122 +20,115 @@
using namespace mlir;
using namespace mlir::vector;
-static LogicalResult
-rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- const WarpExecuteOnLane0LoweringOptions &options) {
- assert(warpOp.getBodyRegion().hasOneBlock() &&
- "expected WarpOp with single block");
- Block *warpOpBody = &warpOp.getBodyRegion().front();
- Location loc = warpOp.getLoc();
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
+static llvm::Optional<int64_t>
+getDistributedVectorDim(VectorType distributedVectorType) {
+ return (distributedVectorType)
+ ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
+ : llvm::None;
+}
- // Passed all checks. Start rewriting.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(warpOp);
+static llvm::Optional<int64_t>
+getDistributedSize(VectorType distributedVectorType) {
+ auto dim = getDistributedVectorDim(distributedVectorType);
+ return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
+ : llvm::None;
+}
- // Create scf.if op.
- Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- warpOp.getLaneid(), c0);
- auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
- /*withElseRegion=*/false);
- rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
-
- // Store vectors that are defined outside of warpOp into the scratch pad
- // buffer.
- SmallVector<Value> bbArgReplacements;
- for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
- Value val = it.value();
- Value bbArg = warpOpBody->getArgument(it.index());
-
- rewriter.setInsertionPoint(ifOp);
- Value buffer =
- options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
-
- // Store arg vector into buffer.
- rewriter.setInsertionPoint(ifOp);
- auto vectorType = val.getType().cast<VectorType>();
- int64_t storeSize = vectorType.getShape()[0];
- Value storeOffset = rewriter.create<arith::MulIOp>(
- loc, warpOp.getLaneid(),
- rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
- rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
-
- // Load bbArg vector from buffer.
- rewriter.setInsertionPointToStart(ifOp.thenBlock());
- auto bbArgType = bbArg.getType().cast<VectorType>();
- Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
- bbArgReplacements.push_back(loadOp);
+namespace {
+
+/// Helper struct to create the load / store operations that permit transit
+/// through the parallel / sequential and the sequential / parallel boundaries
+/// when performing `rewriteWarpOpToScfFor`.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size ?
+/// TODO: add an analysis step that determines which vector dimension should
+/// be used for distribution.
+struct DistributedLoadStoreHelper {
+ DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
+ Value laneId, Value zero)
+ : sequentialVal(sequentialVal), distributedVal(distributedVal),
+ laneId(laneId), zero(zero) {
+ sequentialType = sequentialVal.getType();
+ distributedType = distributedVal.getType();
+ sequentialVectorType = sequentialType.dyn_cast<VectorType>();
+ distributedVectorType = distributedType.dyn_cast<VectorType>();
}
- // Insert sync after all the stores and before all the loads.
- if (!warpOp.getArgs().empty()) {
- rewriter.setInsertionPoint(ifOp);
- options.warpSyncronizationFn(loc, rewriter, warpOp);
+ Value buildDistributedOffset(RewriterBase &b, Location loc) {
+ auto maybeDistributedSize = getDistributedSize(distributedVectorType);
+ assert(maybeDistributedSize &&
+ "at this point, a distributed size must be determined");
+ AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
+ return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+ ArrayRef<Value>{laneId});
}
- // Move body of warpOp to ifOp.
- rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
-
- // Rewrite terminator and compute replacements of WarpOp results.
- SmallVector<Value> replacements;
- auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
- Location yieldLoc = yieldOp.getLoc();
- for (const auto &it : llvm::enumerate(yieldOp.operands())) {
- Value val = it.value();
- Type resultType = warpOp->getResultTypes()[it.index()];
- rewriter.setInsertionPoint(ifOp);
- Value buffer =
- options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
-
- // Store yielded value into buffer.
- rewriter.setInsertionPoint(yieldOp);
- if (val.getType().isa<VectorType>())
- rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
- else
- rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
-
- // Load value from buffer (after warpOp).
- rewriter.setInsertionPointAfter(ifOp);
- if (resultType == val.getType()) {
- // Result type and yielded value type are the same. This is a broadcast.
- // E.g.:
- // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
- // vector.yield %cst : f32
- // }
- // Both types are f32. The constant %cst is broadcasted to all lanes.
- // This is described in more detail in the documentation of the op.
- Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
- replacements.push_back(loadOp);
- } else {
- auto loadedVectorType = resultType.cast<VectorType>();
- int64_t loadSize = loadedVectorType.getShape()[0];
-
- // loadOffset = laneid * loadSize
- Value loadOffset = rewriter.create<arith::MulIOp>(
- loc, warpOp.getLaneid(),
- rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
- Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
- buffer, loadOffset);
- replacements.push_back(loadOp);
+ Operation *buildStore(RewriterBase &b, Location loc, Value val,
+ Value buffer) {
+ assert((val == distributedVal || val == sequentialVal) &&
+ "Must store either the preregistered distributed or the "
+ "preregistered sequential value.");
+ // Vector case must use vector::TransferWriteOp which will later lower to
+ // vector.store or memref.store depending on further lowerings.
+ if (val.getType().isa<VectorType>()) {
+ SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+ auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+ assert(maybeDistributedDim && "must be able to deduce distributed dim");
+ if (val == distributedVal)
+ indices[*maybeDistributedDim] =
+ (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
+ SmallVector<bool> inBounds(indices.size(), true);
+ return b.create<vector::TransferWriteOp>(
+ loc, val, buffer, indices,
+ ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
+ // Scalar case can directly use memref.store.
+ return b.create<memref::StoreOp>(loc, val, buffer, zero);
}
- // Insert sync after all the stores and before all the loads.
- if (!yieldOp.operands().empty()) {
- rewriter.setInsertionPointAfter(ifOp);
- options.warpSyncronizationFn(loc, rewriter, warpOp);
+ Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
+ bool broadcastMode = false) {
+ // When broadcastMode is true, this is a broadcast.
+ // E.g.:
+ // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+ // vector.yield %cst : f32
+ // }
+ // Both types are f32. The constant %cst is broadcasted to all lanes.
+ // This is described in more detail in the documentation of the op.
+ if (broadcastMode)
+ return b.create<memref::LoadOp>(loc, buffer, zero);
+
+ // Other cases must be vector atm.
+ // Vector case must use vector::TransferReadOp which will later lower to
+ // vector.load or memref.load depending on further lowerings.
+ assert(type.isa<VectorType>() && "must be a vector type");
+ assert((type == distributedVectorType || type == sequentialVectorType) &&
+ "Must store either the preregistered distributed or the "
+ "preregistered sequential type.");
+ auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+ assert(maybeDistributedDim && "must be able to deduce distributed dim");
+ SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+ if (type == distributedVectorType) {
+ indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
+ } else {
+ indices[*maybeDistributedDim] = zero;
+ }
+ SmallVector<bool> inBounds(indices.size(), true);
+ return b.create<vector::TransferReadOp>(
+ loc, type.cast<VectorType>(), buffer, indices,
+ ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
- // Delete terminator and add empty scf.yield.
- rewriter.eraseOp(yieldOp);
- rewriter.setInsertionPointToEnd(ifOp.thenBlock());
- rewriter.create<scf::YieldOp>(yieldLoc);
-
- // Compute replacements for WarpOp results.
- rewriter.replaceOp(warpOp, replacements);
+ Value sequentialVal, distributedVal, laneId, zero;
+ Type sequentialType, distributedType;
+ VectorType sequentialVectorType, distributedVectorType;
+};
- return success();
-}
+} // namespace
/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
@@ -261,6 +254,37 @@ static AffineMap calculateImplicitMap(Value yield, Value ret) {
namespace {
+/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
+/// thread `laneId` executes the entirety of the computation.
+///
+/// After the transformation:
+/// - the IR within the scf.if op can be thought of as executing sequentially
+/// (from the point of view of threads along `laneId`).
+/// - the IR outside of the scf.if op can be thought of as executing in
+/// parallel (from the point of view of threads along `laneId`).
+///
+/// Values that need to transit through the parallel / sequential and the
+/// sequential / parallel boundaries do so via reads and writes to a temporary
+/// memory location.
+///
+/// The transformation proceeds in multiple steps:
+/// 1. Create the scf.if op.
+/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
+/// within the scf.if to transit the values captured from above.
+/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
+/// consistent within the scf.if.
+/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
+/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
+/// transit the values returned by the op.
+/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
+/// consistent after the scf.if.
+/// 7. Perform late cleanups.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size ?
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpToScfForPattern(MLIRContext *context,
const WarpExecuteOnLane0LoweringOptions &options,
@@ -270,7 +294,106 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- return rewriteWarpOpToScfFor(rewriter, warpOp, options);
+ assert(warpOp.getBodyRegion().hasOneBlock() &&
+ "expected WarpOp with single block");
+ Block *warpOpBody = &warpOp.getBodyRegion().front();
+ Location loc = warpOp.getLoc();
+
+ // Passed all checks. Start rewriting.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(warpOp);
+
+ // Step 1: Create scf.if op.
+ Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value isLane0 = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
+ auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
+ /*withElseRegion=*/false);
+ rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
+
+ // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
+ // reads within the scf.if to transit the values captured from above.
+ SmallVector<Value> bbArgReplacements;
+ for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
+ Value sequentialVal = warpOpBody->getArgument(it.index());
+ Value distributedVal = it.value();
+ DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+ warpOp.getLaneid(), c0);
+
+ // Create buffer before the ifOp.
+ rewriter.setInsertionPoint(ifOp);
+ Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+ sequentialVal.getType());
+ // Store distributed vector into buffer, before the ifOp.
+ helper.buildStore(rewriter, loc, distributedVal, buffer);
+ // Load sequential vector from buffer, inside the ifOp.
+ rewriter.setInsertionPointToStart(ifOp.thenBlock());
+ bbArgReplacements.push_back(
+ helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
+ }
+
+ // Step 3. Insert sync after all the stores and before all the loads.
+ if (!warpOp.getArgs().empty()) {
+ rewriter.setInsertionPoint(ifOp);
+ options.warpSyncronizationFn(loc, rewriter, warpOp);
+ }
+
+ // Step 4. Move body of warpOp to ifOp.
+ rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
+
+ // Step 5. Insert appropriate writes within scf.if and reads after the
+ // scf.if to transit the values returned by the op.
+ // TODO: at this point, we can reuse the shared memory from previous
+ // buffers.
+ SmallVector<Value> replacements;
+ auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
+ Location yieldLoc = yieldOp.getLoc();
+ for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+ Value sequentialVal = it.value();
+ Value distributedVal = warpOp->getResult(it.index());
+ DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+ warpOp.getLaneid(), c0);
+
+ // Create buffer before the ifOp.
+ rewriter.setInsertionPoint(ifOp);
+ Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+ sequentialVal.getType());
+
+ // Store yielded value into buffer, inside the ifOp, before the
+ // terminator.
+ rewriter.setInsertionPoint(yieldOp);
+ helper.buildStore(rewriter, loc, sequentialVal, buffer);
+
+ // Load distributed value from buffer, after the warpOp.
+ rewriter.setInsertionPointAfter(ifOp);
+ bool broadcastMode =
+ (sequentialVal.getType() == distributedVal.getType());
+ // Result type and yielded value type are the same. This is a broadcast.
+ // E.g.:
+ // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+ // vector.yield %cst : f32
+ // }
+ // Both types are f32. The constant %cst is broadcasted to all lanes.
+ // This is described in more detail in the documentation of the op.
+ replacements.push_back(helper.buildLoad(
+ rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+ }
+
+ // Step 6. Insert sync after all the stores and before all the loads.
+ if (!yieldOp.operands().empty()) {
+ rewriter.setInsertionPointAfter(ifOp);
+ options.warpSyncronizationFn(loc, rewriter, warpOp);
+ }
+
+ // Step 7. Delete terminator and add empty scf.yield.
+ rewriter.eraseOp(yieldOp);
+ rewriter.setInsertionPointToEnd(ifOp.thenBlock());
+ rewriter.create<scf::YieldOp>(yieldLoc);
+
+ // Compute replacements for WarpOp results.
+ rewriter.replaceOp(warpOp, replacements);
+
+ return success();
}
private:
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index cceb04b4d52f..8f7e867b5b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -4,7 +4,9 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
-
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-SCF-IF-DAG: #[[$TIMES8:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_128xf32 : memref<128xf32, 3>
@@ -16,17 +18,14 @@
func.func @rewrite_warp_op_to_scf_if(%laneid: index,
%v0: vector<4xf32>, %v1: vector<8xf32>) {
// CHECK-SCF-IF-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-SCF-IF-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-SCF-IF-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-SCF-IF-DAG: %[[c8:.*]] = arith.constant 8 : index
// CHECK-SCF-IF: %[[is_lane_0:.*]] = arith.cmpi eq, %[[laneid]], %[[c0]]
// CHECK-SCF-IF: %[[buffer_v0:.*]] = memref.get_global @__shared_128xf32
-// CHECK-SCF-IF: %[[s0:.*]] = arith.muli %[[laneid]], %[[c4]]
-// CHECK-SCF-IF: vector.store %[[v0]], %[[buffer_v0]][%[[s0]]]
+// CHECK-SCF-IF: %[[s0:.*]] = affine.apply #[[$TIMES4]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v0]], %[[buffer_v0]][%[[s0]]]
// CHECK-SCF-IF: %[[buffer_v1:.*]] = memref.get_global @__shared_256xf32
-// CHECK-SCF-IF: %[[s1:.*]] = arith.muli %[[laneid]], %[[c8]]
-// CHECK-SCF-IF: vector.store %[[v1]], %[[buffer_v1]][%[[s1]]]
+// CHECK-SCF-IF: %[[s1:.*]] = affine.apply #[[$TIMES8]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v1]], %[[buffer_v1]][%[[s1]]]
// CHECK-SCF-IF-DAG: gpu.barrier
// CHECK-SCF-IF-DAG: %[[buffer_def_0:.*]] = memref.get_global @__shared_32xf32
@@ -36,21 +35,21 @@ func.func @rewrite_warp_op_to_scf_if(%laneid: index,
%r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
args(%v0, %v1 : vector<4xf32>, vector<8xf32>) -> (vector<1xf32>, vector<2xf32>) {
^bb0(%arg0: vector<128xf32>, %arg1: vector<256xf32>):
-// CHECK-SCF-IF: %[[arg1:.*]] = vector.load %[[buffer_v1]][%[[c0]]] : memref<256xf32, 3>, vector<256xf32>
-// CHECK-SCF-IF: %[[arg0:.*]] = vector.load %[[buffer_v0]][%[[c0]]] : memref<128xf32, 3>, vector<128xf32>
+// CHECK-SCF-IF: %[[arg1:.*]] = vector.transfer_read %[[buffer_v1]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<256xf32, 3>, vector<256xf32>
+// CHECK-SCF-IF: %[[arg0:.*]] = vector.transfer_read %[[buffer_v0]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<128xf32, 3>, vector<128xf32>
// CHECK-SCF-IF: %[[def_0:.*]] = "some_def"(%[[arg0]]) : (vector<128xf32>) -> vector<32xf32>
// CHECK-SCF-IF: %[[def_1:.*]] = "some_def"(%[[arg1]]) : (vector<256xf32>) -> vector<64xf32>
%2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
%3 = "some_def"(%arg1) : (vector<256xf32>) -> vector<64xf32>
-// CHECK-SCF-IF: vector.store %[[def_0]], %[[buffer_def_0]][%[[c0]]]
-// CHECK-SCF-IF: vector.store %[[def_1]], %[[buffer_def_1]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_0]], %[[buffer_def_0]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_1]], %[[buffer_def_1]][%[[c0]]]
vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
}
// CHECK-SCF-IF: }
// CHECK-SCF-IF: gpu.barrier
-// CHECK-SCF-IF: %[[o1:.*]] = arith.muli %[[laneid]], %[[c2]]
-// CHECK-SCF-IF: %[[r1:.*]] = vector.load %[[buffer_def_1]][%[[o1]]] : memref<64xf32, 3>, vector<2xf32>
-// CHECK-SCF-IF: %[[r0:.*]] = vector.load %[[buffer_def_0]][%[[laneid]]] : memref<32xf32, 3>, vector<1xf32>
+// CHECK-SCF-IF: %[[o1:.*]] = affine.apply #[[$TIMES2]]()[%[[laneid]]]
+// CHECK-SCF-IF: %[[r1:.*]] = vector.transfer_read %[[buffer_def_1]][%[[o1]]], %{{.*}} {in_bounds = [true]} : memref<64xf32, 3>, vector<2xf32>
+// CHECK-SCF-IF: %[[r0:.*]] = vector.transfer_read %[[buffer_def_0]][%[[laneid]]], %{{.*}} {in_bounds = [true]} : memref<32xf32, 3>, vector<1xf32>
// CHECK-SCF-IF: "some_use"(%[[r0]]) : (vector<1xf32>) -> ()
// CHECK-SCF-IF: "some_use"(%[[r1]]) : (vector<2xf32>) -> ()
"some_use"(%r#0) : (vector<1xf32>) -> ()
@@ -631,3 +630,23 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
}
return %r : f32
}
+
+// -----
+
+// CHECK-PROP: func @lane_dependent_warp_propagate_read
+// CHECK-PROP-SAME: %[[ID:.*]]: index
+func.func @lane_dependent_warp_propagate_read(
+ %laneid: index, %src: memref<1x1024xf32>, %dest: memref<1x1024xf32>) {
+ // CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-PROP-NOT: lane_dependent_warp_propagate_read
+ // CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[C0]], %[[ID]]], %{{.*}} : memref<1x1024xf32>, vector<1x1xf32>
+ // CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1x1xf32>, memref<1x1024xf32>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<1x1024xf32>, vector<1x32xf32>
+ vector.yield %2 : vector<1x32xf32>
+ }
+ vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
+ return
+}