[Mlir-commits] [mlir] 6080271 - [mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 15 09:02:56 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-15T16:01:59Z
New Revision: 60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c
URL: https://github.com/llvm/llvm-project/commit/60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c
DIFF: https://github.com/llvm/llvm-project/commit/60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c.diff
LOG: [mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class
Differential Revision: https://reviews.llvm.org/D111893
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 14f8b29f66892..e318bd5cf3e1a 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -19,11 +19,14 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
+class PatternRewriter;
+
/// Tests whether the given maps describe a row major matmul. The test is
/// permutation-invariant. Note that this only checks the affine maps from an
/// operation, so does not perform any checks on the math being performed within
@@ -132,6 +135,60 @@ inline StringRef toString(IteratorType t) {
llvm_unreachable("Unsupported IteratorType");
}
+/// Helper StructuredGenerator class to manipulate and rewrite ops with
+/// `StructuredOpInterface`. This is templated for now because VectorOps do not
+/// yet implement the StructuredOpInterface itself.
+template <typename StructuredOpInterface>
+class StructuredGenerator {
+public:
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+  struct IteratorType {
+    IteratorType(StringRef strRef) : strRef(strRef) {}
+    bool isOfType(Attribute attr) const {
+      auto sAttr = attr.dyn_cast<StringAttr>();
+      return sAttr && sAttr.getValue() == strRef;
+    }
+    StringRef strRef;
+  };
+  struct Par : public IteratorType {
+    Par() : IteratorType(getParallelIteratorTypeName()) {}
+  };
+  struct Red : public IteratorType {
+    Red() : IteratorType(getReductionIteratorTypeName()) {}
+  };
+  struct Win : public IteratorType {
+    Win() : IteratorType(getWindowIteratorTypeName()) {}
+  };
+
+  StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
+      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  /// Return true if the op's iterator types match `its` exactly, in order.
+  bool iters(ArrayRef<IteratorType> its) const {
+    if (its.size() != iterators.size())
+      return false;
+    for (size_t i = 0, e = its.size(); i != e; ++i)
+      if (!its[i].isOfType(iterators[i]))
+        return false;
+    return true;
+  }
+
+  /// Return true if the op's indexing maps equal those inferred from `l`.
+  bool layout(MapList l) const {
+    return maps == AffineMap::inferFromExprList(l);
+  }
+
+protected:
+  PatternRewriter &rewriter;
+  MLIRContext *ctx;
+  Location loc;
+  ArrayAttr iterators;
+  SmallVector<AffineMap, 4> maps;
+  Operation *op;
+};
+
} // end namespace mlir
#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2b83c23730aad..d1b9835452b7d 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1252,35 +1252,22 @@ struct Red : public IteratorType {
Red() : IteratorType(getReductionIteratorTypeName()) {}
};
-// Unroll outer-products along reduction.
-struct UnrolledOuterProductEmitter {
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+/// Generate a vector implementation for matmat, matvec and tmatvec.
+/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator
+ : public StructuredGenerator<vector::ContractionOp> {
- UnrolledOuterProductEmitter(PatternRewriter &rewriter,
- vector::ContractionOp op)
- : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
- iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+ UnrolledOuterProductGenerator(PatternRewriter &rewriter,
+ vector::ContractionOp op)
+ : StructuredGenerator<vector::ContractionOp>(rewriter, op),
+ kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
+ lhsType(op.getLhsType()) {}
Value t(Value v) {
static constexpr std::array<int64_t, 2> perm = {1, 0};
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
- bool iters(ArrayRef<IteratorType> its) {
- if (its.size() != iterators.size())
- return false;
- for (int i = 0, e = its.size(); i != e; ++i) {
- if (!its[i].isOfType(iterators[i]))
- return false;
- }
- return true;
- }
-
- bool layout(MapList l) {
- auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
- return maps == infer(l);
- }
-
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0);
for (int64_t k = 0; k < reductionSize; ++k) {
@@ -1293,128 +1280,132 @@ struct UnrolledOuterProductEmitter {
return success();
}
- PatternRewriter &rewriter;
- Location loc;
- vector::CombiningKind kind;
- ArrayAttr iterators;
- SmallVector<AffineMap, 4> maps;
- Operation *op;
-};
-} // namespace
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-/// %at = vector.transpose %a, [1, 0]
-/// %bRow0 = vector.extract %b[0]
-/// %atRow0 = vector.extract %at[0]
-/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
-/// ...
-/// %bRowK = vector.extract %b[K]
-/// %atRowK = vector.extract %at[K]
-/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
-/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
- // TODO: implement masks
- if (llvm::size(op.masks()) != 0)
- return failure();
-
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::OuterProduct)
- return failure();
-
- if (failed(filter(op)))
- return failure();
-
- VectorType lhsType = op.getLhsType();
- Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
-
- //
- // Two outer parallel, one inner reduction (matmat flavor).
- //
- UnrolledOuterProductEmitter e(rewriter, op);
- if (e.iters({Par(), Par(), Red()})) {
- // Set up the parallel/reduction structure in right form.
+ /// Two outer parallel, one inner reduction (matmat flavor).
+ LogicalResult matmat() {
+ if (!iters({Par(), Par(), Red()}))
+ return failure();
+ // Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
// Classical row-major matmul: Just permute the lhs.
- if (e.layout({{m, k}, {k, n}, {m, n}}))
- return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {k, n}, {m, n}}))
+ return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
- if (e.layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = e.t(lhs);
- return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {n, k}, {m, n}})) {
+ Value tlhs = t(lhs);
+ return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1));
}
// No need to permute anything.
- if (e.layout({{k, m}, {k, n}, {m, n}}))
- return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {k, n}, {m, n}}))
+ return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Just permute the rhs.
- if (e.layout({{k, m}, {n, k}, {m, n}}))
- return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {n, k}, {m, n}}))
+ return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
- if (e.layout({{m, k}, {k, n}, {n, m}}))
- return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {k, n}, {n, m}}))
+ return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
- if (e.layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = e.t(rhs);
- return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {n, k}, {n, m}})) {
+ Value trhs = t(rhs);
+ return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1));
}
- if (e.layout({{k, m}, {k, n}, {n, m}}))
- return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
- if (e.layout({{k, m}, {n, k}, {n, m}}))
- return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {k, n}, {n, m}}))
+ return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {n, k}, {n, m}}))
+ return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
return failure();
}
- //
- // One outer parallel, one inner reduction (matvec flavor)
- //
- if (e.iters({Par(), Red()})) {
+ /// One outer parallel, one inner reduction (matvec flavor)
+ LogicalResult matvec() {
+ if (!iters({Par(), Red()}))
+ return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
// Case mat-vec: transpose.
- if (e.layout({{m, k}, {k}, {m}}))
- return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {k}, {m}}))
+ return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go.
- if (e.layout({{k, m}, {k}, {m}}))
- return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {k}, {m}}))
+ return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose.
- if (e.layout({{k}, {m, k}, {m}}))
- return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+ if (layout({{k}, {m, k}, {m}}))
+ return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go.
- if (e.layout({{k}, {k, m}, {m}}))
- return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+ if (layout({{k}, {k, m}, {m}}))
+ return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
//
// One outer reduction, one inner parallel (tmatvec flavor)
//
- if (e.iters({Red(), Par()})) {
+ LogicalResult tmatvec() {
+ if (!iters({Red(), Par()}))
+ return failure();
AffineExpr k, m;
bindDims(rewriter.getContext(), k, m);
// Case mat-vec: transpose.
- if (e.layout({{m, k}, {k}, {m}}))
- return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+ if (layout({{m, k}, {k}, {m}}))
+ return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go.
- if (e.layout({{k, m}, {k}, {m}}))
- return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {k}, {m}}))
+ return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose.
- if (e.layout({{k}, {m, k}, {m}}))
- return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+ if (layout({{k}, {m, k}, {m}}))
+ return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go.
- if (e.layout({{k}, {k, m}, {m}}))
- return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+ if (layout({{k}, {k, m}, {m}}))
+ return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
+private:
+ vector::CombiningKind kind;
+ Value lhs, rhs, res;
+ VectorType lhsType;
+};
+} // namespace
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+/// %at = vector.transpose %a, [1, 0]
+/// %bRow0 = vector.extract %b[0]
+/// %atRow0 = vector.extract %at[0]
+/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
+/// ...
+/// %bRowK = vector.extract %b[K]
+/// %atRowK = vector.extract %at[K]
+/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  // Bail out on anything this pattern does not handle: masked contractions,
+  // a lowering strategy other than OuterProduct, or ops rejected by filter.
+  // TODO: implement masks
+  bool hasMasks = llvm::size(op.masks()) != 0;
+  bool isOuterProductLowering =
+      vectorTransformOptions.vectorContractLowering ==
+      vector::VectorContractLowering::OuterProduct;
+  if (hasMasks || !isOuterProductLowering || failed(filter(op)))
+    return failure();
+
+  // Try each supported contraction flavor in turn. Each helper returns
+  // failure() before creating any op when the iterator types or indexing
+  // maps do not match, so falling through to the next flavor is safe.
+  UnrolledOuterProductGenerator gen(rewriter, op);
+  if (succeeded(gen.matmat()) || succeeded(gen.matvec()) ||
+      succeeded(gen.tmatvec()))
+    return success();
+
+  // No supported flavor matched; the op is left unchanged.
return failure();
}
More information about the Mlir-commits mailing list