[Mlir-commits] [mlir] ba10daa - [mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue May 26 12:45:16 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-26T15:40:55-04:00
New Revision: ba10daa820fa868816eed2b85e70197d354ebfe6
URL: https://github.com/llvm/llvm-project/commit/ba10daa820fa868816eed2b85e70197d354ebfe6
DIFF: https://github.com/llvm/llvm-project/commit/ba10daa820fa868816eed2b85e70197d354ebfe6.diff
LOG: [mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.
This revision expands the types of vector contractions that can be lowered to vector.outerproduct.
All 8 permutation cases are supported.
The idiomatic manipulation of AffineMap written declaratively makes this straightforward.
In the process a bug with the vector.contract verifier was uncovered.
The vector shape verification part of the contract op is rewritten to use AffineMap composition.
One bug in the vector `ops.mlir` test is fixed, and a new case not yet captured is added
to the vector `invalid.mlir` test.
Differential Revision: https://reviews.llvm.org/D80393
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 5a36aabfab75..02d276256076 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -81,12 +81,24 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
+constexpr bool isParallelIterator(Attribute attr) {
+ auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+ return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+}
/// Use to encode that a particular iterator type has reduction semantics.
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
+constexpr bool isReductionIterator(Attribute attr) {
+ auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+ return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+}
/// Use to encode that a particular iterator type has window semantics.
constexpr StringRef getWindowIteratorTypeName() { return "window"; }
+constexpr bool isWindowIterator(Attribute attr) {
+ auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+ return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
+}
/// Use to encode that a particular iterator type has window semantics.
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1574edb34494..63891d1004d4 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -223,8 +223,9 @@ static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
return true;
}
-static bool verifyOutputShape(
- VectorType lhsType, VectorType rhsType, Type accType, Type resType,
+static LogicalResult verifyOutputShape(
+ ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
+ Type resType,
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
DenseSet<int64_t> lhsContractingDimSet;
@@ -256,26 +257,56 @@ static bool verifyOutputShape(
if (expectedResultDims.size() == 0) {
// No batch or free dimension implies a scalar result.
if (resType.isa<VectorType>() || accType.isa<VectorType>())
- return false;
-
+ return op.emitOpError("invalid accumulator/result vector shape");
} else {
// At least one batch or free dimension implies a vector result.
auto resVectorType = resType.dyn_cast<VectorType>();
auto accVectorType = accType.dyn_cast<VectorType>();
if (!resVectorType || !accVectorType)
- return false;
-
- // Verify dimension from 'resType' against 'expectedResultDims'.
- if (resVectorType.getShape().size() != expectedResultDims.size() ||
- accVectorType.getShape().size() != expectedResultDims.size())
- return false;
- for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) {
- if (resVectorType.getDimSize(i) != expectedResultDims[i] ||
- accVectorType.getDimSize(i) != expectedResultDims[i])
- return false;
+ return op.emitOpError("invalid accumulator/result vector shape");
+
+ // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
+ // types fully define the result vector type. This assumes the affine maps
+ // are well-formed, which must have been verified already.
+ MLIRContext *ctx = op.getContext();
+ AffineMap lhsMap = op.getIndexingMaps()[0];
+ AffineMap rhsMap = op.getIndexingMaps()[1];
+ SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
+ for (auto pair :
+ {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
+ VectorType v = pair.first;
+ auto map = pair.second;
+ for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
+ unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
+ if (!extents[pos])
+ extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
+ }
}
+ assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
+ "expected extent along all dimensions.");
+
+ AffineMap resMap = op.getIndexingMaps()[2];
+ auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
+ /*symCount=*/0, extents, ctx);
+ // Compose the resMap with the extentsMap, which is a constant map.
+ AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
+ assert(llvm::all_of(
+ expectedMap.getResults(),
+ [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
+ "expected constant extent along all dimensions.");
+ // Extract the expected shape and build the type.
+ auto expectedShape = llvm::to_vector<4>(
+ llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
+ return e.cast<AffineConstantExpr>().getValue();
+ }));
+ auto expected =
+ VectorType::get(expectedShape, resVectorType.getElementType());
+ if (resVectorType != expected || accVectorType != expected)
+ return op.emitOpError(
+ "invalid accumulator/result vector shape, expected: ")
+ << expected;
}
- return true;
+ return success();
}
static LogicalResult verify(ContractionOp op) {
@@ -329,9 +360,9 @@ static LogicalResult verify(ContractionOp op) {
return op.emitOpError("invalid batch dimension map");
// Verify 'accType' and 'resType' shape.
- if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
- batchDimMap))
- return op.emitOpError("invalid accumulator/result vector shape");
+ if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
+ contractingDimMap, batchDimMap)))
+ return failure();
// Verify that either two vector masks are set or none are set.
auto lhsMaskType = op.getLHSVectorMaskType();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 44ff03a04f22..491ad62affcb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1454,10 +1454,17 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
if (llvm::size(op.masks()) != 0)
return failure();
+ auto iteratorTypes = op.iterator_types().getValue();
+ if (!isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]))
+ return failure();
+
if (vectorTransformsOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul ||
!isRowMajorMatmul(op.indexing_maps()))
return failure();
+
return success();
}
@@ -1503,34 +1510,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-void ContractionOpToOuterProductOpLowering::rewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
- VectorType lhsType = op.getLhsType();
- // TODO(ntv) other modes.
- // We know we are in row-major.
- bool transposeLhs = false;
- unsigned reductionSize =
- transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
-
- // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
- // transpose it to extract the proper vector<m x f32>. Otherwise, just take
- // the lhs.
- Value lhs = transposeLhs
- ? op.lhs()
- : rewriter.create<vector::TransposeOp>(
- op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
- Value res = op.acc();
- // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
- for (unsigned k = 0; k < reductionSize; ++k) {
- Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
- Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
- res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
- }
- rewriter.replaceOp(op, res);
-}
-
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult
ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
// TODO(ajcbik): implement masks
@@ -1538,12 +1519,104 @@ ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
return failure();
if (vectorTransformsOptions.vectorContractLowering !=
- vector::VectorContractLowering::OuterProduct ||
- !isRowMajorMatmul(op.indexing_maps()))
+ vector::VectorContractLowering::OuterProduct)
+ return failure();
+
+ // Transpose arguments to make them ready for lowering to OuterProduct. The
+ // constraint to match is that we must load full rows at a time with
+ // vector::ExtractOp.
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ AffineExpr m, n, k;
+ bindDims(op.getContext(), m, n, k);
+ auto iteratorTypes = op.iterator_types().getValue();
+ if (!isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]))
+ return failure();
+ SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+ // When lowering to outerproduct we can support all permutations.
+ if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
+ maps != infer({{m, k}, {n, k}, {m, n}}) &&
+ maps != infer({{k, m}, {k, n}, {m, n}}) &&
+ maps != infer({{k, m}, {n, k}, {m, n}}) &&
+ maps != infer({{m, k}, {k, n}, {n, m}}) &&
+ maps != infer({{m, k}, {n, k}, {n, m}}) &&
+ maps != infer({{k, m}, {k, n}, {n, m}}) &&
+ maps != infer({{k, m}, {n, k}, {n, m}}))
return failure();
return success();
}
+void ContractionOpToOuterProductOpLowering::rewrite(
+ vector::ContractionOp op, PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ unsigned reductionSize = 0;
+ VectorType lhsType = op.getLhsType();
+ Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+ // Transpose arguments to make them ready for lowering to OuterProduct. The
+ // constraint to match is that we must load full rows at a time with
+ // vector::ExtractOp.
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ AffineExpr m, n, k;
+ bindDims(rewriter.getContext(), m, n, k);
+ SmallVector<int64_t, 2> perm{1, 0};
+ SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+ // First batch of cases, no need to output permute.
+ if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+ // This is the classical row-major matmul. Just permute the lhs.
+ reductionSize = lhsType.getShape()[1];
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ reductionSize = lhsType.getShape()[1];
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+ // No need to permute anything.
+ reductionSize = lhsType.getShape()[0];
+ } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+ // Just permute the rhs.
+ reductionSize = lhsType.getShape()[0];
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ }
+ // Second batch of cases, reshuffle to avoid output permute.
+ else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+ // This is the classical row-major matmul. Just permute the lhs.
+ reductionSize = lhsType.getShape()[1];
+ Value tmp = rhs;
+ rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = tmp;
+ } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ reductionSize = lhsType.getShape()[1];
+ Value tmp = rhs;
+ rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+ } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+ // No need to permute anything, but still swap lhs and rhs.
+ reductionSize = lhsType.getShape()[0];
+ std::swap(lhs, rhs);
+ } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+ // Just permute the rhs.
+ reductionSize = lhsType.getShape()[0];
+ Value tmp = lhs;
+ lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = tmp;
+ }
+ assert(reductionSize > 0);
+
+ // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+ for (unsigned k = 0; k < reductionSize; ++k) {
+ Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
+ Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
+ res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
+ }
+ rewriter.replaceOp(op, res);
+}
+
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c18cf38edfc9..cc72511a6e78 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -767,6 +767,26 @@ func @contraction(%arg0: vector<4x3xi32>,
// -----
+#contraction_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#contraction_trait = {
+ indexing_maps = #contraction_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<3x2xf32>
+{
+// expected-error at +1 {{invalid accumulator/result vector shape, expected: 'vector<3x2xf32>'}}
+ %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+// -----
+
func @create_mask() {
%c2 = constant 2 : index
%c3 = constant 3 : index
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c194cbe23811..57c03c903fe8 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -160,9 +160,11 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
-#contraction_accesses1 = [
+#contraction_accesses1 = [ // 7, 8, 16, 15
affine_map<(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2)>,
+ // 8, 16, 7, 5
affine_map<(f0, f1, f2, f3, c0, c1) -> (f1, c1, c0, f3)>,
+ // 8, 8, 15, 5
affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)>
]
#contraction_trait1 = {
@@ -172,7 +174,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
}
// CHECK-LABEL: contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
- %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
+ %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
%arg4 : index) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
@@ -181,16 +183,16 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
// Test contraction with only contracting dims. In this case the lhs/rhs
// dimension of size 8 will be considered a parallel dim for lhs/rhs and will
// appear twice in the output.
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+ // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
- : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+ : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// Test contraction with optional vector mask arguments.
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+ // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
- : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+ : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 7eea3baa8d87..1dd2f377a29c 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -681,3 +681,219 @@ func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
return %0 : vector<2x3xi1>
}
+
+#matmat_accesses_0 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_0 = {
+ indexing_maps = #matmat_accesses_0,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_0
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_1 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+ indexing_maps = #matmat_accesses_1,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_1
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_2 = [
+ affine_map<(m, n, k) -> (k, m)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+ indexing_maps = #matmat_accesses_2,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_2
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_3 = [
+ affine_map<(m, n, k) -> (k, m)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+ indexing_maps = #matmat_accesses_3,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_3
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_4 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+ indexing_maps = #matmat_accesses_4,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_4
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+ %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_5 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_5 = {
+ indexing_maps = #matmat_accesses_5,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_5
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+ %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_6 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_6 = {
+ indexing_maps = #matmat_accesses_6,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_6
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+ %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_7 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_7 = {
+ indexing_maps = #matmat_accesses_7,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_7
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+ %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
More information about the Mlir-commits
mailing list