[Mlir-commits] [mlir] 1c84800 - [mlir][vector] Add patterns to propagate vector distribution
Thomas Raoux
llvmlistbot at llvm.org
Mon Jun 13 09:39:16 PDT 2022
Author: Thomas Raoux
Date: 2022-06-13T16:38:50Z
New Revision: 1c84800c42d2183a29392c175c8d5f20a4be65d2
URL: https://github.com/llvm/llvm-project/commit/1c84800c42d2183a29392c175c8d5f20a4be65d2
DIFF: https://github.com/llvm/llvm-project/commit/1c84800c42d2183a29392c175c8d5f20a4be65d2.diff
LOG: [mlir][vector] Add patterns to propagate vector distribution
Add patterns to propagate vector distribution out of
vector.warp_execute_on_lane_0 regions and to remove dead results. This
handles propagation for several vector operations (elementwise ops,
transfer_read, broadcast, and scf.for).
Differential Revision: https://reviews.llvm.org/D127167
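For reference, the new entry point is applied like any other pattern
population function. A minimal sketch (the wrapper function below is
illustrative, not part of this patch; see the TestVectorDistribution change
further down for the in-tree usage):

    #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Collect the warp distribution propagation patterns and apply them
    // greedily to `op`.
    static void applyWarpPropagation(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::vector::populatePropagateWarpVectorDistributionPatterns(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }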
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index b95b527d0639c..5af0da2f8528c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -65,6 +65,10 @@ void populateDistributeTransferWriteOpPatterns(
/// region.
void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+/// Collect patterns to propagate warp distribution.
+void populatePropagateWarpVectorDistributionPatterns(
+ RewritePatternSet &pattern);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 586604f6fd6c3..7f8566aa6c477 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
using namespace mlir;
@@ -181,6 +182,60 @@ static bool canBeHoisted(Operation *op,
isSideEffectFree(op) && op->getNumRegions() == 0;
}
+/// Return a yield operand of `warpOp` whose defining op satisfies the filter
+/// lambda condition and whose matching warp op result is not dead.
+static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+ std::function<bool(Operation *)> fn) {
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (definedOp && fn(definedOp)) {
+ if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ return &yieldOperand;
+ }
+ }
+ return {};
+}
+
+/// Clones `op` into a new operation that takes `operands` and returns
+/// `resultTypes`.
+static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
+ Location loc, Operation *op,
+ ArrayRef<Value> operands,
+ ArrayRef<Type> resultTypes) {
+ OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
+ op->getAttrs());
+ return rewriter.create(res);
+}
+
+/// Currently the distribution map is implicit based on the vector shape. In the
+/// future it will be part of the op.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
+/// ...
+/// vector.yield %3 : vector<32x16x64xf32>
+/// }
+/// ```
+/// Would have an implicit map of:
+/// `(d0, d1, d2) -> (d0, d2)`
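+/// since dimensions 0 and 2 of the yielded type differ from the result type
+/// (32 vs. 1 and 64 vs. 2), making them the distributed dimensions.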
+static AffineMap calculateImplicitMap(Value yield, Value ret) {
+ auto srcType = yield.getType().cast<VectorType>();
+ auto dstType = ret.getType().cast<VectorType>();
+ SmallVector<AffineExpr> perm;
+  // Check which dimensions of the yield value are different from the
+  // dimensions of the result to know the distributed dimensions. Then
+  // associate each distributed dimension to an ID in order.
+ for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
+ if (srcType.getDimSize(i) != dstType.getDimSize(i))
+ perm.push_back(getAffineDimExpr(i, yield.getContext()));
+ }
+ auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
+ return map;
+}
+
namespace {
struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -350,6 +405,322 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
DistributionMapFn distributionMapFn;
};
+/// Sink out an elementwise op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+/// ...
+/// %3 = arith.addf %1, %2 : vector<32xf32>
+/// vector.yield %3 : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
+/// vector<1xf32>, vector<1xf32>) {
+/// ...
+/// %4 = arith.addf %2, %3 : vector<32xf32>
+/// vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
+/// vector<32xf32>
+/// }
+/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
+/// ```
+struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
+ return OpTrait::hasElementwiseMappableTraits(op);
+ });
+ if (!yieldOperand)
+ return failure();
+ Operation *elementWise = yieldOperand->get().getDefiningOp();
+ unsigned operandIndex = yieldOperand->getOperandNumber();
+ Value distributedVal = warpOp.getResult(operandIndex);
+ SmallVector<Value> yieldValues;
+ SmallVector<Type> retTypes;
+ Location loc = warpOp.getLoc();
+ for (OpOperand &operand : elementWise->getOpOperands()) {
+ Type targetType;
+ if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
+ // If the result type is a vector, the operands must also be vectors.
+ auto operandType = operand.get().getType().cast<VectorType>();
+ targetType =
+ VectorType::get(vecType.getShape(), operandType.getElementType());
+ } else {
+ auto operandType = operand.get().getType();
+ assert(!operandType.isa<VectorType>() &&
+ "unexpected yield of vector from op with scalar result type");
+ targetType = operandType;
+ }
+ retTypes.push_back(targetType);
+ yieldValues.push_back(operand.get());
+ }
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, yieldValues, retTypes);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Value> newOperands(elementWise->getOperands().begin(),
+ elementWise->getOperands().end());
+ for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
+ newOperands[i] = newWarpOp.getResult(i + warpOp.getNumResults());
+ }
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Operation *newOp = cloneOpWithOperandsAndTypes(
+ rewriter, loc, elementWise, newOperands,
+ {newWarpOp.getResult(operandIndex).getType()});
+ newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
+ return success();
+ }
+};
+
+/// Sink out a transfer_read op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+/// ...
+/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
+/// vector<32xf32>
+/// vector.yield %2 : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+/// ...
+/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
+/// vector<32xf32>
+/// vector.yield %2 : vector<32xf32>
+/// }
+/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
+/// ```
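+///
+/// Each distributed index becomes `base + dimSize * laneid`, where `dimSize`
+/// is the size of the matching distributed dimension of the result type;
+/// e.g. a vector<64xf32> read distributed to vector<2xf32> across 32 lanes
+/// makes lane `l` read two elements starting at `base + 2 * l` (the
+/// `d0 + scale * d1` affine apply below).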
+struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
+ if (!operand)
+ return failure();
+ auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
+ unsigned operandIndex = operand->getOperandNumber();
+ Value distributedVal = warpOp.getResult(operandIndex);
+
+ SmallVector<Value, 4> indices(read.getIndices().begin(),
+ read.getIndices().end());
+ AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
+ AffineMap indexMap = map.compose(read.getPermutationMap());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(warpOp);
+ for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+ AffineExpr d0, d1;
+ bindDims(read.getContext(), d0, d1);
+ auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ if (!indexExpr)
+ continue;
+ unsigned indexPos = indexExpr.getPosition();
+ unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+ int64_t scale =
+ distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
+ indices[indexPos] =
+ makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
+ {indices[indexPos], warpOp.getLaneid()});
+ }
+ Value newRead = rewriter.create<vector::TransferReadOp>(
+ read.getLoc(), distributedVal.getType(), read.getSource(), indices,
+ read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+ read.getInBoundsAttr());
+ distributedVal.replaceAllUsesWith(newRead);
+ return success();
+ }
+};
+
+/// Remove any result that has no use along with the matching yieldOp operand.
+// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
+struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Type> resultTypes;
+ SmallVector<Value> yieldValues;
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
+ resultTypes.push_back(result.getType());
+ yieldValues.push_back(yield.getOperand(result.getResultNumber()));
+ }
+ if (yield.getNumOperands() == yieldValues.size())
+ return failure();
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, yieldValues, resultTypes);
+ unsigned resultIndex = 0;
+ for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
+ result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
+ }
+ rewriter.eraseOp(warpOp);
+ return success();
+ }
+};
+
+/// If a yielded value is defined outside the warp op region, or is one of the
+/// warp op's block arguments, forward it to the matching result directly; it
+/// does not need to go through the region.
+struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Type> resultTypes;
+ SmallVector<Value> yieldValues;
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ Value valForwarded;
+ unsigned resultIndex;
+ for (OpOperand &operand : yield->getOpOperands()) {
+ Value result = warpOp.getResult(operand.getOperandNumber());
+ if (result.use_empty())
+ continue;
+
+ // Assume all the values coming from above are uniform.
+ if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
+ if (result.getType() != operand.get().getType())
+ continue;
+ valForwarded = operand.get();
+ resultIndex = operand.getOperandNumber();
+ break;
+ }
+ auto arg = operand.get().dyn_cast<BlockArgument>();
+ if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
+ continue;
+ Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
+ if (result.getType() != warpOperand.getType())
+ continue;
+ valForwarded = warpOperand;
+ resultIndex = operand.getOperandNumber();
+ break;
+ }
+ if (!valForwarded)
+ return failure();
+ warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
+ return success();
+ }
+};
+
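+/// Sink out a vector.broadcast feeding into a warp op yield: yield the
+/// broadcast source instead, and rebuild the broadcast to the distributed
+/// result type after the warp op.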
+struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
+ Location loc = broadcastOp.getLoc();
+ auto destVecType =
+ warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {broadcastOp.getSource()},
+ {broadcastOp.getSource().getType()});
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value broadcasted = rewriter.create<vector::BroadcastOp>(
+ loc, destVecType, newWarpOp->getResults().back());
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
+
+ return success();
+ }
+};
+
+/// Sink an scf.for op out of a WarpExecuteOnLane0Op. This can be done only if
+/// the scf.for is the last operation in the region, so that sinking it does
+/// not change the order of execution. The rewrite creates a new scf.for after
+/// the WarpExecuteOnLane0Op, and the body of the new scf.for contains a new
+/// WarpExecuteOnLane0Op region. Example:
+/// ```
+/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
+/// ...
+/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
+/// -> (vector<128xf32>) {
+/// ...
+/// scf.yield %r : vector<128xf32>
+/// }
+/// vector.yield %v1 : vector<128xf32>
+/// }
+/// ```
+/// To:
+/// ```
+/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
+/// ...
+/// vector.yield %v : vector<128xf32>
+/// }
+/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
+/// -> (vector<4xf32>) {
+/// %iw = vector.warp_execute_on_lane_0(%laneid)
+/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
+/// ^bb0(%arg: vector<128xf32>):
+/// ...
+/// vector.yield %ir : vector<128xf32>
+/// }
+/// scf.yield %iw : vector<4xf32>
+/// }
+/// ```
+struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ // Only pick up forOp if it is the last op in the region.
+ Operation *lastNode = yield->getPrevNode();
+ auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
+ if (!forOp)
+ return failure();
+ SmallVector<Value> newOperands;
+ SmallVector<unsigned> resultIdx;
+ // Collect all the outputs coming from the forOp.
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+ continue;
+ auto forResult = yieldOperand.get().cast<OpResult>();
+ newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
+ yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
+ resultIdx.push_back(yieldOperand.getOperandNumber());
+ }
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(warpOp);
+ // Create a new for op outside the region with a WarpExecuteOnLane0Op region
+ // inside.
+ auto newForOp = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newOperands);
+ rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+ auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
+ warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
+ warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
+ forOp.getResultTypes());
+
+ SmallVector<Value> argMapping;
+ argMapping.push_back(newForOp.getInductionVar());
+ for (Value args : innerWarp.getBody()->getArguments()) {
+ argMapping.push_back(args);
+ }
+ SmallVector<Value> yieldOperands;
+ for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+ yieldOperands.push_back(operand);
+ rewriter.eraseOp(forOp.getBody()->getTerminator());
+ rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+ rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
+ rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
+ rewriter.setInsertionPointAfter(innerWarp);
+ rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+ rewriter.eraseOp(forOp);
+ // Replace the warpOp result coming from the original ForOp.
+ for (const auto &res : llvm::enumerate(resultIdx)) {
+ warpOp.getResult(res.value())
+ .replaceAllUsesWith(newForOp.getResult(res.index()));
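+      // The first three scf.for operands are the lower bound, upper bound,
+      // and step, so the iter-arg operands start at index 3.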
+ newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
+ }
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -363,6 +734,13 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}
+void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
+ WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
+ patterns.getContext());
+}
+
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
Block *body = warpOp.getBody();
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dc4dfee861fb7..b57791ad04a19 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
@@ -126,4 +127,310 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
}
return
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_dead_result(
+func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
+ // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+ %r:3 = vector.warp_execute_on_lane_0(%laneid)[32] ->
+ (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+ %2 = "some_def"() : () -> (vector<32xf32>)
+ %3 = "some_def"() : () -> (vector<32xf32>)
+ %4 = "some_def"() : () -> (vector<32xf32>)
+ // CHECK-PROP: vector.yield %{{.*}} : vector<32xf32>
+ vector.yield %2, %3, %4 : vector<32xf32>, vector<32xf32>, vector<32xf32>
+ }
+ // CHECK-PROP: return %[[R]] : vector<1xf32>
+ return %r#1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_propagate_operand(
+// CHECK-PROP-SAME: %[[ID:.*]]: index, %[[V:.*]]: vector<4xf32>)
+func.func @warp_propagate_operand(%laneid: index, %v0: vector<4xf32>)
+ -> (vector<4xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0 : vector<4xf32>) -> (vector<4xf32>) {
+ ^bb0(%arg0 : vector<128xf32>) :
+ vector.yield %arg0 : vector<128xf32>
+ }
+ // CHECK-PROP: return %[[V]] : vector<4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-PROP-LABEL: func @warp_propagate_elementwise(
+func.func @warp_propagate_elementwise(%laneid: index, %dest: memref<1024xf32>) {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK-PROP: %[[R:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>, vector<2xf32>, vector<2xf32>)
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] ->
+ (vector<1xf32>, vector<2xf32>) {
+ // CHECK-PROP: %[[V0:.*]] = "some_def"() : () -> vector<32xf32>
+ // CHECK-PROP: %[[V1:.*]] = "some_def"() : () -> vector<32xf32>
+ // CHECK-PROP: %[[V2:.*]] = "some_def"() : () -> vector<64xf32>
+ // CHECK-PROP: %[[V3:.*]] = "some_def"() : () -> vector<64xf32>
+ // CHECK-PROP: vector.yield %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<32xf32>, vector<32xf32>, vector<64xf32>, vector<64xf32>
+ %2 = "some_def"() : () -> (vector<32xf32>)
+ %3 = "some_def"() : () -> (vector<32xf32>)
+ %4 = "some_def"() : () -> (vector<64xf32>)
+ %5 = "some_def"() : () -> (vector<64xf32>)
+ %6 = arith.addf %2, %3 : vector<32xf32>
+ %7 = arith.addf %4, %5 : vector<64xf32>
+ vector.yield %6, %7 : vector<32xf32>, vector<64xf32>
+ }
+ // CHECK-PROP: %[[A0:.*]] = arith.addf %[[R]]#2, %[[R]]#3 : vector<2xf32>
+ // CHECK-PROP: %[[A1:.*]] = arith.addf %[[R]]#0, %[[R]]#1 : vector<1xf32>
+ %id2 = affine.apply #map0()[%laneid]
+ // CHECK-PROP: vector.transfer_write %[[A1]], {{.*}} : vector<1xf32>, memref<1024xf32>
+ // CHECK-PROP: vector.transfer_write %[[A0]], {{.*}} : vector<2xf32>, memref<1024xf32>
+ vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
+ vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32>
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_propagate_scalar_arith(
+// CHECK-PROP: %[[r:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} {
+// CHECK-PROP: %[[some_def0:.*]] = "some_def"
+// CHECK-PROP: %[[some_def1:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[some_def0]], %[[some_def1]]
+// CHECK-PROP: }
+// CHECK-PROP: arith.addf %[[r]]#0, %[[r]]#1 : f32
+func.func @warp_propagate_scalar_arith(%laneid: index) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %0 = "some_def"() : () -> (f32)
+ %1 = "some_def"() : () -> (f32)
+ %2 = arith.addf %0, %1 : f32
+ vector.yield %2 : f32
+ }
+ vector.print %r : f32
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_propagate_cast(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: %[[result:.*]] = arith.sitofp %{{.*}} : i32 to f32
+// CHECK-PROP: return %[[result]]
+func.func @warp_propagate_cast(%laneid : index, %i : i32) -> (f32) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %casted = arith.sitofp %i : i32 to f32
+ vector.yield %casted : f32
+ }
+ return %r : f32
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-PROP-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-PROP: func @warp_propagate_read
+// CHECK-PROP-SAME: (%[[ID:.*]]: index
+func.func @warp_propagate_read(%laneid: index, %src: memref<1024xf32>, %dest: memref<1024xf32>) {
+// CHECK-PROP-NOT: warp_execute_on_lane_0
+// CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[ID]]], %{{.*}} : memref<1024xf32>, vector<1xf32>
+// CHECK-PROP-DAG: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
+// CHECK-PROP-DAG: %[[R1:.*]] = vector.transfer_read %arg1[%[[ID2]]], %{{.*}} : memref<1024xf32>, vector<2xf32>
+// CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1xf32>, memref<1024xf32>
+// CHECK-PROP: vector.transfer_write %[[R1]], {{.*}} : vector<2xf32>, memref<1024xf32>
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] ->(vector<1xf32>, vector<2xf32>) {
+ %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
+ %3 = vector.transfer_read %src[%c32], %cst : memref<1024xf32>, vector<64xf32>
+ vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
+ }
+ %id2 = affine.apply #map0()[%laneid]
+ vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
+ vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32>
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @fold_vector_broadcast(
+// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+// CHECK-PROP: %[[some_def:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[some_def]] : vector<1xf32>
+// CHECK-PROP: vector.print %[[r]] : vector<1xf32>
+func.func @fold_vector_broadcast(%laneid: index) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+ %0 = "some_def"() : () -> (vector<1xf32>)
+ %1 = vector.broadcast %0 : vector<1xf32> to vector<32xf32>
+ vector.yield %1 : vector<32xf32>
+ }
+ vector.print %r : vector<1xf32>
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @extract_vector_broadcast(
+// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+// CHECK-PROP: %[[some_def:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[some_def]] : vector<1xf32>
+// CHECK-PROP: %[[broadcasted:.*]] = vector.broadcast %[[r]] : vector<1xf32> to vector<2xf32>
+// CHECK-PROP: vector.print %[[broadcasted]] : vector<2xf32>
+func.func @extract_vector_broadcast(%laneid: index) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<1xf32>)
+ %1 = vector.broadcast %0 : vector<1xf32> to vector<64xf32>
+ vector.yield %1 : vector<64xf32>
+ }
+ vector.print %r : vector<2xf32>
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @extract_scalar_vector_broadcast(
+// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (f32)
+// CHECK-PROP: %[[some_def:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[some_def]] : f32
+// CHECK-PROP: %[[broadcasted:.*]] = vector.broadcast %[[r]] : f32 to vector<2xf32>
+// CHECK-PROP: vector.print %[[broadcasted]] : vector<2xf32>
+func.func @extract_scalar_vector_broadcast(%laneid: index) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (f32)
+ %1 = vector.broadcast %0 : f32 to vector<64xf32>
+ vector.yield %1 : vector<64xf32>
+ }
+ vector.print %r : vector<2xf32>
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_scf_for(
+// CHECK-PROP: %[[INI:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>) {
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: vector.yield %[[INI1]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]) -> (vector<4xf32>) {
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]] : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%[[ARG:.*]]: vector<128xf32>):
+// CHECK-PROP: %[[ACC:.*]] = "some_def"(%[[ARG]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: vector.yield %[[ACC]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
+func.func @warp_scf_for(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+ %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc : vector<128xf32>
+ }
+ vector.yield %3 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_scf_for_swap(
+// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: vector.yield %[[INI1]], %[[INI2]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG1:.*]] = %[[INI]]#0, %[[FARG2:.*]] = %[[INI]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG1]], %[[FARG2]] : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: ^bb0(%[[ARG1:.*]]: vector<128xf32>, %[[ARG2:.*]]: vector<128xf32>):
+// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%[[ARG1]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: %[[ACC2:.*]] = "some_def"(%[[ARG2]]) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: vector.yield %[[ACC2]], %[[ACC1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W]]#0, %[[W]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+// CHECK-PROP: "some_use"(%[[F]]#1) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_swap(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0:2 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>) {
+ %ini1 = "some_def"() : () -> (vector<128xf32>)
+ %ini2 = "some_def"() : () -> (vector<128xf32>)
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2) -> (vector<128xf32>, vector<128xf32>) {
+ %acc1 = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+ %acc2 = "some_def"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc2, %acc1 : vector<128xf32>, vector<128xf32>
+ }
+ vector.yield %3#0, %3#1 : vector<128xf32>, vector<128xf32>
+ }
+ "some_use"(%0#0) : (vector<4xf32>) -> ()
+ "some_use"(%0#1) : (vector<4xf32>) -> ()
+ return
+}
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 * 4)>
+#map1 = affine_map<()[s0] -> (s0 * 128 + 128)>
+#map2 = affine_map<()[s0] -> (s0 * 4 + 128)>
+
+// CHECK-PROP-LABEL: func @warp_scf_for_multiple_yield(
+// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP-NEXT: "some_def"() : () -> vector<32xf32>
+// CHECK-PROP-NEXT: vector.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+// CHECK-PROP: %{{.*}}:2 = scf.for {{.*}} -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<?xf32>, vector<4xf32>
+// CHECK-PROP: arith.addf {{.*}} : vector<4xf32>
+// CHECK-PROP: arith.addf {{.*}} : vector<4xf32>
+// CHECK-PROP: scf.yield {{.*}} : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+ %c256 = arith.constant 256 : index
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0:3 = vector.warp_execute_on_lane_0(%arg0)[32] ->
+ (vector<1xf32>, vector<4xf32>, vector<4xf32>) {
+ %def = "some_def"() : () -> (vector<32xf32>)
+ %r1 = vector.transfer_read %arg2[%c0], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+ %r2 = vector.transfer_read %arg2[%c128], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %r1, %arg5 = %r2)
+ -> (vector<128xf32>, vector<128xf32>) {
+ %o1 = affine.apply #map1()[%arg3]
+ %o2 = affine.apply #map2()[%arg3]
+ %4 = vector.transfer_read %arg1[%o1], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+ %5 = vector.transfer_read %arg1[%o2], %cst {in_bounds = [true]} : memref<?xf32>, vector<128xf32>
+ %6 = arith.addf %4, %arg4 : vector<128xf32>
+ %7 = arith.addf %5, %arg5 : vector<128xf32>
+ scf.yield %6, %7 : vector<128xf32>, vector<128xf32>
+ }
+ vector.yield %def, %3#0, %3#1 : vector<32xf32>, vector<128xf32>, vector<128xf32>
+ }
+ %1 = affine.apply #map()[%arg0]
+ vector.transfer_write %0#1, %arg2[%1] {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
+ %2 = affine.apply #map2()[%arg0]
+ vector.transfer_write %0#2, %arg2[%2] {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
+ "some_use"(%0#0) : (vector<1xf32>) -> ()
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
index 2205079f246a4..159c677b96319 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
@@ -11,6 +11,31 @@
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
// RUN: FileCheck %s
+// Run the same test cases with distribution and propagation.
+// RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \
+// RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
+// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: -gpu-kernel-outlining \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN: -gpu-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
+// RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
+// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: -gpu-kernel-outlining \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN: -gpu-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
func.func @gpu_func(%arg1: memref<32xf32>, %arg2: memref<32xf32>) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e1ffddc5f0687..cfdeb1e632e0b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -835,6 +835,10 @@ struct TestVectorDistribution
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
+ Option<bool> propagateDistribution{
+ *this, "propagate-distribution",
+      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};
+
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
@@ -862,7 +866,11 @@ struct TestVectorDistribution
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
-
+ if (propagateDistribution) {
+ RewritePatternSet patterns(ctx);
+ vector::populatePropagateWarpVectorDistributionPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
WarpExecuteOnLane0LoweringOptions options;
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,