[Mlir-commits] [mlir] eca7698 - [mlir][vector] NFC: Expose castAwayContractionLeadingOneDim
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 21 09:44:52 PDT 2023
Author: Lei Zhang
Date: 2023-04-21T09:41:14-07:00
New Revision: eca7698a979ebc2338c075d63ac52c4c21b19cb1
URL: https://github.com/llvm/llvm-project/commit/eca7698a979ebc2338c075d63ac52c4c21b19cb1
DIFF: https://github.com/llvm/llvm-project/commit/eca7698a979ebc2338c075d63ac52c4c21b19cb1.diff
LOG: [mlir][vector] NFC: Expose castAwayContractionLeadingOneDim
This commit exposes the transformation behind the
CastAwayContractionLeadingOneDim pattern as a standalone function.
This is useful for a more targeted, one-time application to a specific op.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D148758
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4763b6525b934..ed25021e421f4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
} // namespace arith
namespace vector {
+class ContractionOp;
class TransferReadOp;
class TransferWriteOp;
class VectorDialect;
@@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Cast away the leading unit dim, if exists, for the given contract op.
+/// Return success if the transformation applies; return failure otherwise.
+LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+ RewriterBase &rewriter);
+
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 849e0442bc7e1..6105e87573c23 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
}
};
+} // namespace
+
+LogicalResult
+mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+ RewriterBase &rewriter) {
+ VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
+ if (oldAccType == nullptr)
+ return failure();
+ if (oldAccType.getRank() < 2)
+ return failure();
+ if (oldAccType.getShape()[0] != 1)
+ return failure();
+ // currently we support only dropping one dim but the pattern can be applied
+ // greedily to drop more.
+ int64_t dropDim = 1;
+
+ auto oldIndexingMaps = contractOp.getIndexingMapsArray();
+ SmallVector<AffineMap> newIndexingMaps;
+
+ auto oldIteratorTypes = contractOp.getIteratorTypes();
+ SmallVector<Attribute> newIteratorTypes;
+
+ int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
+
+ if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
+ // only parallel type iterators can be dropped.
+ return failure();
+
+ for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
+ int64_t currDim = it.index();
+ if (currDim == dimToDrop)
+ continue;
+ newIteratorTypes.push_back(it.value());
+ }
+
+ SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
+ contractOp.getAcc()};
+ SmallVector<Value> newOperands;
+
+ for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
+ // Check if the dim to be dropped exists as a leading dim in the operand
+ // if it does then we use vector.extract to drop it.
+ bool validExtract = false;
+ SmallVector<AffineExpr> results;
+ auto map = it.value();
+ int64_t orginalZeroDim = it.value().getDimPosition(0);
+ if (orginalZeroDim != dimToDrop) {
+ // There are two reasons to be in this path, 1. We need to
+ // tranpose the operand to make the dim to be dropped
+ // leading. 2. The dim to be dropped does not exist and in
+ // that case we dont want to add a unit tranpose but we must
+ // check all the indices to make sure this is the case.
+ bool tranposeNeeded = false;
+ SmallVector<int64_t> perm;
+ SmallVector<AffineExpr> transposeResults;
+
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t currDim = map.getDimPosition(i);
+ if (currDim == dimToDrop) {
+ tranposeNeeded = true;
+ perm.insert(perm.begin(), i);
+ auto targetExpr = rewriter.getAffineDimExpr(currDim);
+ transposeResults.insert(transposeResults.begin(), targetExpr);
+ } else {
+ perm.push_back(i);
+ auto targetExpr = rewriter.getAffineDimExpr(currDim);
+ transposeResults.push_back(targetExpr);
+ }
+ }
+ // Do the tranpose now if needed so that we can drop the
+ // correct dim using extract later.
+ if (tranposeNeeded) {
+ map = AffineMap::get(map.getNumDims(), 0, transposeResults,
+ contractOp.getContext());
+ operands[it.index()] = rewriter.create<vector::TransposeOp>(
+ contractOp.getLoc(), operands[it.index()], perm);
+ }
+ }
+ // We have taken care to have the dim to be dropped be
+ // the leading dim. If its still not leading that means it
+ // does not exist in this operand and hence we do not need
+ // an extract.
+ if (map.getDimPosition(0) == dimToDrop)
+ validExtract = true;
+
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t currDim = map.getDimPosition(i);
+ if (currDim == dimToDrop)
+ // This is the dim we are dropping.
+ continue;
+ auto targetExpr = rewriter.getAffineDimExpr(
+ currDim < dimToDrop ? currDim : currDim - 1);
+ results.push_back(targetExpr);
+ }
+ newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
+ contractOp.getContext()));
+ // Extract if its a valid extraction, otherwise use the operand
+ // without extraction.
+ newOperands.push_back(
+ validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
+ operands[it.index()],
+ splatZero(dropDim))
+ : operands[it.index()]);
+ }
+ auto newContractOp = rewriter.create<vector::ContractionOp>(
+ contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+ rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ contractOp, contractOp->getResultTypes()[0], newContractOp);
+ return success();
+}
+
+namespace {
+
/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
@@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
- if (oldAccType == nullptr)
- return failure();
- if (oldAccType.getRank() < 2)
- return failure();
- if (oldAccType.getShape()[0] != 1)
- return failure();
- // currently we support only dropping one dim but the pattern can be applied
- // greedily to drop more.
- int64_t dropDim = 1;
-
- auto oldIndexingMaps = contractOp.getIndexingMapsArray();
- SmallVector<AffineMap> newIndexingMaps;
-
- auto oldIteratorTypes = contractOp.getIteratorTypes();
- SmallVector<Attribute> newIteratorTypes;
-
- int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
-
- if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
- // only parallel type iterators can be dropped.
- return failure();
-
- for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
- int64_t currDim = it.index();
- if (currDim == dimToDrop)
- continue;
- newIteratorTypes.push_back(it.value());
- }
-
- SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
- contractOp.getAcc()};
- SmallVector<Value> newOperands;
-
- for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
- // Check if the dim to be dropped exists as a leading dim in the operand
- // if it does then we use vector.extract to drop it.
- bool validExtract = false;
- SmallVector<AffineExpr> results;
- auto map = it.value();
- int64_t orginalZeroDim = it.value().getDimPosition(0);
- if (orginalZeroDim != dimToDrop) {
- // There are two reasons to be in this path, 1. We need to
- // tranpose the operand to make the dim to be dropped
- // leading. 2. The dim to be dropped does not exist and in
- // that case we dont want to add a unit tranpose but we must
- // check all the indices to make sure this is the case.
- bool tranposeNeeded = false;
- SmallVector<int64_t> perm;
- SmallVector<AffineExpr> transposeResults;
-
- for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t currDim = map.getDimPosition(i);
- if (currDim == dimToDrop) {
- tranposeNeeded = true;
- perm.insert(perm.begin(), i);
- auto targetExpr = rewriter.getAffineDimExpr(currDim);
- transposeResults.insert(transposeResults.begin(), targetExpr);
- } else {
- perm.push_back(i);
- auto targetExpr = rewriter.getAffineDimExpr(currDim);
- transposeResults.push_back(targetExpr);
- }
- }
- // Do the tranpose now if needed so that we can drop the
- // correct dim using extract later.
- if (tranposeNeeded) {
- map = AffineMap::get(map.getNumDims(), 0, transposeResults,
- contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
- }
- }
- // We have taken care to have the dim to be dropped be
- // the leading dim. If its still not leading that means it
- // does not exist in this operand and hence we do not need
- // an extract.
- if (map.getDimPosition(0) == dimToDrop)
- validExtract = true;
-
- for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t currDim = map.getDimPosition(i);
- if (currDim == dimToDrop)
- // This is the dim we are dropping.
- continue;
- auto targetExpr = rewriter.getAffineDimExpr(
- currDim < dimToDrop ? currDim : currDim - 1);
- results.push_back(targetExpr);
- }
- newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
- contractOp.getContext()));
- // Extract if its a valid extraction, otherwise use the operand
- // without extraction.
- newOperands.push_back(validExtract
- ? rewriter.create<vector::ExtractOp>(
- contractOp.getLoc(), operands[it.index()],
- splatZero(dropDim))
- : operands[it.index()]);
- }
- auto newContractOp = rewriter.create<vector::ContractionOp>(
- contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
- rewriter.getAffineMapArrayAttr(newIndexingMaps),
- rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- contractOp, contractOp->getResultTypes()[0], newContractOp);
- return success();
+ return castAwayContractionLeadingOneDim(contractOp, rewriter);
}
};
More information about the Mlir-commits
mailing list