[Mlir-commits] [mlir] cb5de7c - [mlir][Vector] NFC - Compress vector to outerproduct lowering.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 2 14:24:59 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-02T21:23:59Z
New Revision: cb5de7c813f976dd458bd2a7f40702ba648bf650
URL: https://github.com/llvm/llvm-project/commit/cb5de7c813f976dd458bd2a7f40702ba648bf650
DIFF: https://github.com/llvm/llvm-project/commit/cb5de7c813f976dd458bd2a7f40702ba648bf650.diff
LOG: [mlir][Vector] NFC - Compress vector to outerproduct lowering.
The implementation has become too unwieldy and the cognitive overhead is too high.
Instead, compress the implementation in preparation for additional lowering paths.
This is a resubmit of https://reviews.llvm.org/D105359 without the ordering ambiguities.
Differential Revision: https://reviews.llvm.org/D105367
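
For illustration, here is a minimal sketch of the rewrite this pattern implements, shown on the classical row-major matmul case. The operand names and shapes below are hypothetical; the transformation itself follows the pattern's own documentation in the diff.

  // Hypothetical input: a 2x4 * 4x3 row-major matmul written as a contraction.
  %c = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"]}
       %a, %b, %acc : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>

  // After the rewrite: permute the lhs, then unroll along the reduction
  // dimension (size 4) into a chain of vector.outerproduct ops.
  %at = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
  %a0 = vector.extract %at[0] : vector<4x2xf32>
  %b0 = vector.extract %b[0] : vector<4x3xf32>
  %c0 = vector.outerproduct %a0, %b0, %acc : vector<2xf32>, vector<3xf32>
  %a1 = vector.extract %at[1] : vector<4x2xf32>
  %b1 = vector.extract %b[1] : vector<4x3xf32>
  %c1 = vector.outerproduct %a1, %b1, %c0 : vector<2xf32>, vector<3xf32>
  // ... and so on up to k = 3.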
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1a7d2e80d56f..3342f9a7482b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1816,6 +1816,72 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return success();
}
+namespace {
+struct IteratorType {
+ IteratorType(StringRef strRef) : strRef(strRef) {}
+ bool isOfType(Attribute attr) const {
+ auto sAttr = attr.dyn_cast<StringAttr>();
+ return sAttr && sAttr.getValue() == strRef;
+ }
+ StringRef strRef;
+};
+struct Par : public IteratorType {
+ Par() : IteratorType(getParallelIteratorTypeName()) {}
+};
+struct Red : public IteratorType {
+ Red() : IteratorType(getReductionIteratorTypeName()) {}
+};
+
+// Unroll outer-products along reduction.
+struct UnrolledOuterProductEmitter {
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+ UnrolledOuterProductEmitter(PatternRewriter &rewriter,
+ vector::ContractionOp op)
+ : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
+ iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+ Value t(Value v) {
+ static constexpr std::array<int64_t, 2> perm = {1, 0};
+ return rewriter.create<vector::TransposeOp>(loc, v, perm);
+ }
+
+ bool iters(ArrayRef<IteratorType> its) {
+ if (its.size() != iterators.size())
+ return false;
+ for (int i = 0, e = its.size(); i != e; ++i) {
+ if (!its[i].isOfType(iterators[i]))
+ return false;
+ }
+ return true;
+ }
+
+ bool layout(MapList l) {
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ return maps == infer(l);
+ }
+
+ LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
+ assert(reductionSize > 0);
+ for (int64_t k = 0; k < reductionSize; ++k) {
+ Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+ Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+ res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
+ res, kind);
+ }
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ PatternRewriter &rewriter;
+ Location loc;
+ vector::CombiningKind kind;
+ ArrayAttr iterators;
+ SmallVector<AffineMap, 4> maps;
+ Operation *op;
+};
+} // namespace
+
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
@@ -1844,104 +1910,68 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (failed(filter(op)))
return failure();
- Location loc = op.getLoc();
- int64_t reductionSize = 0;
VectorType lhsType = op.getLhsType();
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
// Set up the parallel/reduction structure in right form.
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- static constexpr std::array<int64_t, 2> perm = {1, 0};
- auto iteratorTypes = op.iterator_types().getValue();
- SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
- if (isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])) {
- //
- // Two outer parallel, one inner reduction (matmat flavor).
- //
- if (maps == infer({{m, k}, {k, n}, {m, n}})) {
- // This is the classical row-major matmul. Just permute the lhs.
- reductionSize = lhsType.getDimSize(1);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
- reductionSize = lhsType.getDimSize(1);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- // No need to permute anything.
- reductionSize = lhsType.getDimSize(0);
- } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
- // Just permute the rhs.
- reductionSize = lhsType.getDimSize(0);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
- // This is the classical row-major matmul. Just permute the lhs.
- reductionSize = lhsType.getDimSize(1);
- Value tmp = rhs;
- rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- lhs = tmp;
- } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
- reductionSize = lhsType.getDimSize(1);
- Value tmp = rhs;
- rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
- } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
- // No need to permute anything, but still swap lhs and rhs.
- reductionSize = lhsType.getDimSize(0);
- std::swap(lhs, rhs);
- } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
- // Just permute the rhs.
- reductionSize = lhsType.getDimSize(0);
- Value tmp = lhs;
- lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- rhs = tmp;
- } else {
- return failure();
+
+ //
+ // Two outer parallel, one inner reduction (matmat flavor).
+ //
+ UnrolledOuterProductEmitter e(rewriter, op);
+ if (e.iters({Par(), Par(), Red()})) {
+ // Classical row-major matmul: Just permute the lhs.
+ if (e.layout({{m, k}, {k, n}, {m, n}}))
+ return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ if (e.layout({{m, k}, {n, k}, {m, n}})) {
+ Value tlhs = e.t(lhs);
+ return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
}
- } else if (isParallelIterator(iteratorTypes[0]) &&
- isReductionIterator(iteratorTypes[1])) {
- //
- // One outer parallel, one inner reduction (matvec flavor)
- //
- if (maps == infer({{m, n}, {n}, {m}})) {
- // Case mat-vec: transpose.
- reductionSize = lhsType.getDimSize(1);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{n, m}, {n}, {m}})) {
- // Case mat-trans-vec: ready to go.
- reductionSize = lhsType.getDimSize(0);
- } else if (maps == infer({{n}, {m, n}, {m}})) {
- // Case vec-mat: swap and transpose.
- reductionSize = lhsType.getDimSize(0);
- std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{n}, {n, m}, {m}})) {
- // Case vec-mat-trans: swap and ready to go.
- reductionSize = lhsType.getDimSize(0);
- std::swap(lhs, rhs);
- } else {
- return failure();
+ // No need to permute anything.
+ if (e.layout({{k, m}, {k, n}, {m, n}}))
+ return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+ // Just permute the rhs.
+ if (e.layout({{k, m}, {n, k}, {m, n}}))
+ return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+ // Transposed output: swap RHS and LHS.
+ // Classical row-major matmul: permute the lhs.
+ if (e.layout({{m, k}, {k, n}, {n, m}}))
+ return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ if (e.layout({{m, k}, {n, k}, {n, m}})) {
+ Value trhs = e.t(rhs);
+ return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
}
- } else {
+ if (e.layout({{k, m}, {k, n}, {n, m}}))
+ return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+ if (e.layout({{k, m}, {n, k}, {n, m}}))
+ return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
return failure();
}
- assert(reductionSize > 0);
-
- // Unroll outer-products along reduction.
- for (int64_t k = 0; k < reductionSize; ++k) {
- Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
- Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
- res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
- b, res, op.kind());
+
+ //
+ // One outer parallel, one inner reduction (matvec flavor)
+ //
+ if (e.iters({Par(), Red()})) {
+ // Case mat-vec: transpose.
+ if (e.layout({{m, n}, {n}, {m}}))
+ return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+ // Case mat-trans-vec: ready to go.
+ if (e.layout({{n, m}, {n}, {m}}))
+ return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+ // Case vec-mat: swap and transpose.
+ if (e.layout({{n}, {m, n}, {m}}))
+ return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+ // Case vec-mat-trans: swap and ready to go.
+ if (e.layout({{n}, {n, m}, {m}}))
+ return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+ return failure();
}
- rewriter.replaceOp(op, res);
- return success();
+
+ return failure();
}
LogicalResult
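
As a closing note, the matvec flavor handled at the end of the pattern unrolls the same way, but into AXPY-style outer products that take a scalar second operand. A hypothetical sketch (names and shapes invented for illustration):

  %r = vector.contract {
         indexing_maps = [affine_map<(m, n) -> (m, n)>,
                          affine_map<(m, n) -> (n)>,
                          affine_map<(m, n) -> (m)>],
         iterator_types = ["parallel", "reduction"]}
       %A, %x, %acc : vector<2x3xf32>, vector<3xf32> into vector<2xf32>

  // After the rewrite: transpose the matrix, extract one column and one
  // scalar per reduction step, and accumulate with vector.outerproduct.
  %At = vector.transpose %A, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
  %A0 = vector.extract %At[0] : vector<3x2xf32>
  %x0 = vector.extract %x[0] : vector<3xf32>
  %r0 = vector.outerproduct %A0, %x0, %acc : vector<2xf32>, f32
  // ... repeated for the remaining reduction indices.

The UnrolledOuterProductEmitter compresses what used to be two nested if/else ladders into a flat table of (iterators, layout) cases, which is what leaves room for the additional lowering paths mentioned in the log.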