[Mlir-commits] [mlir] 3f906c5 - [mlir][Vector] Add 2-D vector contract lowering to ReduceOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Aug 7 03:21:26 PDT 2020
Author: Nicolas Vasilache
Date: 2020-08-07T06:17:48-04:00
New Revision: 3f906c54a2def26ba0c407a309c7c30ce0e0cc83
URL: https://github.com/llvm/llvm-project/commit/3f906c54a2def26ba0c407a309c7c30ce0e0cc83
DIFF: https://github.com/llvm/llvm-project/commit/3f906c54a2def26ba0c407a309c7c30ce0e0cc83.diff
LOG: [mlir][Vector] Add 2-D vector contract lowering to ReduceOp
This new pattern mixes vector.transpose and direct lowering to vector.reduction.
This allows more progressive lowering than immediately going to insert/extract and
composes more nicely with other canonicalizations.
This has 2 use cases:
1. for very wide vectors the generated IR may be much smaller
2. when we have a custom lowering for transpose ops we can target it directly
rather than relying on LLVM
Differential Revision: https://reviews.llvm.org/D85428
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 35855b3b2137..9587c56c0255 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -258,6 +258,51 @@ class ContractionOpToOuterProductOpLowering
FilterConstraintType filter;
};
+/// Progressive lowering of a `vector.contract %a, %b, %c` with matmul or
+/// matvec semantics to an output-size-unrolled sequence of dot products. For
+/// the row-major matmul layout this produces:
+/// ```
+/// %out = constant ... : vector<MxNxelt_type>
+/// %bt = vector.transpose %b, [1, 0]
+/// %aRow0 = vector.extract %a[0]
+/// %btRow0 = vector.extract %bt[0]
+/// %ab00 = mulf %aRow0, %btRow0
+/// %c00 = vector.reduction "add", %ab00
+/// %out00 = vector.insert %c00, %out[0, 0]
+/// ...
+/// %aRowLast = vector.extract %a[M-1]
+/// %btRowLast = vector.extract %bt[N-1]
+/// %abLastLast = mulf %aRowLast, %btRowLast
+/// %cLastLast = vector.reduction "add", %abLastLast
+/// %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+/// %res = addf %outLastLast, %c
+/// ```
+/// The other supported operand/result layouts first transpose and/or swap
+/// the operands so that the reduction dimension is innermost in both.
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot, the
+/// contraction is unmasked, `filter` accepts the op, and the vector.contract
+/// op is a matmul or matvec in one of the supported layouts.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Accepts every op; used when the caller supplies no constraint.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  /// `constraint` is consulted on every match attempt; ops for which it
+  /// returns failure are not rewritten by this pattern.
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions),
+        // Store the caller-provided constraint (previously `defaultFilter`
+        // was stored here, silently discarding the argument).
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+  /// Extra user-provided precondition checked before rewriting.
+  FilterConstraintType filter;
+};
+
/// Progressive lowering of ContractionOp.
///
/// One:
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 218715318a86..b60b6b39a2b7 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -137,8 +137,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getVectorType",
/*args=*/(ins),
/*methodBody=*/"",
- /*defaultImplementation=*/
- "return $_op.vector().getType().template cast<VectorType>();"
+ /*defaultImplementation=*/[{
+ return $_op.vector().getType().template dyn_cast<VectorType>();
+ }]
>,
InterfaceMethod<
/*desc=*/[{ Return the number of dimensions that participate in the
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 922168947ccf..16d10e558b5e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1754,6 +1754,121 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
return success();
}
+/// Lowers a matmul/matvec `vector.contract` to a fully unrolled sequence of
+/// `mulf` + `vector.reduction "add"` dot products (one per output element).
+/// The operands are first transposed and/or swapped so that the reduction
+/// dimension is innermost in both. The accumulator is added with a single
+/// `addf` at the end.
+///
+/// Returns failure without creating any IR in these cases:
+/// - the op is masked;
+/// - the filter rejects it;
+/// - the lowering strategy is not `Dot`;
+/// - the operands and result do not share one floating-point element type;
+/// - the iterator types / indexing maps are not a supported matmul or
+///   matvec layout.
+LogicalResult
+ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) const {
+  // TODO: implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Dot)
+    return failure();
+
+  // Both supported flavors (matmat, matvec) produce a 1-D or 2-D vector.
+  auto dstType = op.getResultType().dyn_cast<VectorType>();
+  if (!dstType || dstType.getRank() > 2)
+    return failure();
+
+  // The unrolled code below multiplies with MulFOp, reduces straight into the
+  // result element type and accumulates with AddFOp. It is therefore only
+  // valid for floating-point contractions with a single element type.
+  Type elementType = dstType.getElementType();
+  if (!elementType.isa<FloatType>() ||
+      op.getLhsType().getElementType() != elementType ||
+      op.getRhsType().getElementType() != elementType)
+    return failure();
+
+  // Check the iterator count before indexing, so that contractions with fewer
+  // iterators are rejected instead of being read out of bounds.
+  auto iteratorTypes = op.iterator_types().getValue();
+  bool isMatmat = iteratorTypes.size() == 3 &&
+                  isParallelIterator(iteratorTypes[0]) &&
+                  isParallelIterator(iteratorTypes[1]) &&
+                  isReductionIterator(iteratorTypes[2]);
+  bool isMatvec = iteratorTypes.size() == 2 &&
+                  isParallelIterator(iteratorTypes[0]) &&
+                  isReductionIterator(iteratorTypes[1]);
+
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  Location loc = op.getLoc();
+  Value lhs = op.lhs(), rhs = op.rhs();
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(rewriter.getContext(), m, n, k);
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  //
+  // In the following we wish to make the reduction dimension innermost so we
+  // can load vectors and just fmul + reduce into a scalar. Afterwards `lhs`
+  // holds one row per output row. `rhs` holds one row per output column
+  // (matmat), or is the 1-D vector shared by every output row (matvec).
+  //
+  if (isMatmat) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // Classical row-major matmul: only the rhs needs transposing.
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // Transposed result (output rows indexed by n). The new lhs is rhs^T
+      // (n x k) and the new rhs is the original lhs (m x k).
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else {
+      return failure();
+    }
+  } else if (isMatvec) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor).
+    //
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else {
+      return failure();
+    }
+  } else {
+    return failure();
+  }
+
+  unsigned rank = dstType.getRank();
+  unsigned dstRows = dstType.getShape()[0];
+  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
+
+  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+  Value res =
+      rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
+  for (unsigned r = 0; r < dstRows; ++r) {
+    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, r);
+    for (unsigned c = 0; c < dstColumns; ++c) {
+      // For matvec the single 1-D rhs is reused for every output row.
+      Value b = rank == 1
+                    ? rhs
+                    : rewriter.create<vector::ExtractOp>(loc, rhs, c);
+      Value prod = rewriter.create<MulFOp>(loc, a, b);
+      Value reduced = rewriter.create<vector::ReductionOp>(
+          loc, elementType, rewriter.getStringAttr("add"), prod,
+          ValueRange{});
+
+      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
+                                              : SmallVector<int64_t, 2>{r, c};
+      res = rewriter.create<vector::InsertOp>(loc, reduced, res, pos);
+    }
+  }
+  // Accumulate once on the whole result instead of per element.
+  if (auto acc = op.acc())
+    res = rewriter.create<AddFOp>(loc, res, acc);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
@@ -1795,6 +1910,9 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
if (succeeded(pat2.matchAndRewrite(op, rewriter)))
return success();
+ ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
+ if (succeeded(pat3.matchAndRewrite(op, rewriter)))
+ return success();
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 6dae907b8bb0..e34e3428c185 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -43,16 +43,15 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
-// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK: return %[[T9]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32>
+// CHECK: return %[[T10]] : vector<2xf32>
func @extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
@@ -78,16 +77,15 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK: %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
+// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[A]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK: %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32>
-// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
+// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[A]] : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK: return %[[T9]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32>
+// CHECK: return %[[T10]] : vector<2xf32>
func @extract_contract3(%arg0: vector<3xf32>,
%arg1: vector<2x3xf32>,
@@ -112,47 +110,31 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// ... bunch of extract insert to transpose B into Bt
+// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
+// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xf32> into f32
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
//
-// CHECK: %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
-// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
-// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
+// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T12]] : vector<2xf32>
+// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xf32> into f32
+// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
//
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
-// CHECK: %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32>
-// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32
-// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
+// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T24]] : vector<2xf32>
+// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xf32> into f32
+// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
//
-// CHECK: %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32>
-// CHECK: %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32
+// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
+// CHECK: %[[T41:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
+// CHECK: %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xf32> into f32
+// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
//
-// CHECK: %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T52:.*]] = addf %[[T43]], %[[C]] : vector<2x2xf32>
// CHECK: return %[[T52]] : vector<2x2xf32>
func @extract_contract4(%arg0: vector<2x2xf32>,
@@ -574,6 +556,31 @@ func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32>
//
// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>
+
+// REDUCE-LABEL: func @matmul
+// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//
+// REDUCE: %[[RES:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// REDUCE-SAME: : vector<4x3xf32> to vector<3x4xf32>
+//
+// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
+// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32>
+// REDUCE-NEXT: %[[ab00:.*]] = mulf %[[a0]], %[[b0]] : vector<4xf32>
+// REDUCE-NEXT: %[[s00:.*]] = vector.reduction "add", %[[ab00]] : vector<4xf32> into f32
+// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32>
+//
+// ...
+//
+// The row-1 lhs extract is followed by the Bt[0] and Bt[1] dot products, so
+// Bt[2] is not the next line.
+// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+// REDUCE: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32>
+// REDUCE-NEXT: %[[ab12:.*]] = mulf %[[a1]], %[[b2]] : vector<4xf32>
+// REDUCE-NEXT: %[[s12:.*]] = vector.reduction "add", %[[ab12]] : vector<4xf32> into f32
+// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32>
+//
+// REDUCE: %[[RES2:.*]] = addf %[[r12]], %[[C]] : vector<2x3xf32>
+// REDUCE: return %[[RES2]] : vector<2x3xf32>
func @matmul(%arg0: vector<2x4xf32>,
%arg1: vector<4x3xf32>,
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
@@ -1056,7 +1063,3 @@ func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg
: vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
return %0 : vector<3x4xf32>
}
-
-
-
-
More information about the Mlir-commits
mailing list