[Mlir-commits] [mlir] 0ba9ee9 - [mlir] [VectorOps] Framework for progressive lowering of vector.contract
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 19 11:36:36 PST 2020
Author: aartbik
Date: 2020-02-19T11:36:11-08:00
New Revision: 0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2
URL: https://github.com/llvm/llvm-project/commit/0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2
DIFF: https://github.com/llvm/llvm-project/commit/0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2.diff
LOG: [mlir] [VectorOps] Framework for progressive lowering of vector.contract
Summary:
Lowers all free/batch dimensions in a vector.contract progressively
into simpler vector.contract operations until a direct vector.reduction
operation is reached. Then lowers 1-D reductions into vector.fma followed by vector.reductionv2.
Still TBD:
multi-dimensional contractions that remain after removing all the parallel dims
Reviewers: nicolasvasilache, andydavis1, rriddle
Reviewed By: andydavis1
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74797
Added:
Modified:
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index fe62666a2838..61a3d556a70b 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -864,6 +865,19 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
};
/// Progressive lowering of ContractionOp.
+/// One:
+/// %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+/// %a = vector.contract with one less free/batch dimension
+/// %b = vector.contract with one less free/batch dimension
+/// ..
+/// %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by an fma/reduction op.
+///
+/// TODO(ajcbik): break down into transpose/reshape/cast ops
+/// when they become available to avoid code dup
+/// TODO(ajcbik): investigate lowering order impact on performance
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -874,16 +888,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
if (llvm::size(op.masks()) != 0)
return matchFailure();
- auto loc = op.getLoc();
- VectorType lhsType = op.getLhsType();
- VectorType rhsType = op.getRhsType();
- Type resType = op.getResultType();
-
- // Find first batch dimension in lhs/rhs, and lower when found.
+ // Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
- // TODO(ajcbik): implement batch
- return matchFailure();
+ int64_t lhsIndex = batchDimMap[0].first;
+ int64_t rhsIndex = batchDimMap[0].second;
+ rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+ return matchSuccess();
}
// Collect contracting dimensions.
@@ -896,24 +907,35 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
rhsContractingDimSet.insert(dimPair.second);
}
- // Find free dimension in lhs/rhs, and lower first when found.
- for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
- if (lhsContractingDimSet.count(i) == 0) {
- // TODO(ajcbik): implement free
- return matchFailure();
+ // Find first free dimension in LHS, and lower when found.
+ VectorType lhsType = op.getLhsType();
+ for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
+ ++lhsIndex) {
+ if (lhsContractingDimSet.count(lhsIndex) == 0) {
+ rewriter.replaceOp(
+ op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+ return matchSuccess();
}
}
- for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
- if (rhsContractingDimSet.count(i) == 0) {
- // TODO(ajcbik): implement free
- return matchFailure();
+
+ // Find first free dimension in RHS, and lower when found.
+ VectorType rhsType = op.getRhsType();
+ for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
+ ++rhsIndex) {
+ if (rhsContractingDimSet.count(rhsIndex) == 0) {
+ rewriter.replaceOp(
+ op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+ return matchSuccess();
}
}
- // Only contraction dimensions remain.
+ // Only contraction dimensions remain at this point; lower them.
+ // TODO(ajcbik): handle multi-dim reductions
+ auto loc = op.getLoc();
+ Type resType = op.getResultType();
if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
rhsType.getRank() == 1) {
- // Handle reduction into scalar.
+
Value zero = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
@@ -924,9 +946,191 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
op.acc());
return matchSuccess();
}
- // TODO(ajcbik): implement more contraction
+
return matchFailure();
}
+
+private:
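+ // Lower one parallel dimension by unrolling the contraction along it.
+ // An index of -1 marks an operand in which the dimension does not occur.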
+ // TODO(ajcbik): consider reusing existing contract unrolling
+ Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+ int64_t rhsIndex, PatternRewriter &rewriter) const {
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ VectorType resType = op.getResultType().cast<VectorType>();
+ // Find the iterator type index and result index.
+ SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+ int64_t iterIndex = -1;
+ int64_t dimSize = -1;
+ if (lhsIndex >= 0) {
+ iterIndex =
+ iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
+ assert((rhsIndex < 0 || iterIndex == iMap[1]
+ .getResult(rhsIndex)
+ .cast<AffineDimExpr>()
+ .getPosition()) &&
+ "parallel index should be free in LHS or batch in LHS/RHS");
+ dimSize = lhsType.getDimSize(lhsIndex);
+ } else {
+ assert(rhsIndex >= 0 && "missing parallel index");
+ iterIndex =
+ iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
+ dimSize = rhsType.getDimSize(rhsIndex);
+ }
+ assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
+ Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
+ assert(lookup.hasValue() && "parallel index not listed in reduction");
+ int64_t resIndex = lookup.getValue();
+ // Construct new iterator types.
+ ArrayAttr iteratorTypes = op.iterator_types();
+ SmallVector<Attribute, 4> lowIterTypes;
+ for (auto it : llvm::enumerate(iteratorTypes)) {
+ int64_t idx = it.index();
+ if (idx == iterIndex) {
+ assert(it.value().cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName() &&
+ "parallel index not marked as such");
+ continue;
+ }
+ lowIterTypes.push_back(it.value());
+ }
+ // Construct new affine map array attribute.
+ SmallVector<AffineMap, 4> lowIndexingMaps;
+ lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+ lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+ lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+ auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+ // Construct new iterator types array attribute.
+ auto lowIter = rewriter.getArrayAttr(lowIterTypes);
+ // Unroll into a series of lower dimensional vector.contract ops.
+ Location loc = op.getLoc();
+ Value result = zeroVector(loc, resType, rewriter);
+ for (int64_t d = 0; d < dimSize; ++d) {
+ auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+ auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
+ Value lowContract = rewriter.create<vector::ContractionOp>(
+ loc, lhs, rhs, acc, lowAffine, lowIter);
+ result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
+ rewriter);
+ }
+ return result;
+ }
+
+ // Helper method to construct a zero vector.
+ static Value zeroVector(Location loc, VectorType vType,
+ PatternRewriter &rewriter) {
+ Type eltType = vType.getElementType();
+ Value zero = rewriter.create<ConstantOp>(loc, eltType,
+ rewriter.getZeroAttr(eltType));
+ return rewriter.create<SplatOp>(loc, vType, zero);
+ }
+
+ // Helper to find the position of the map result that refers to the
+ // given dimension index (None when no result refers to it).
+ static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ if (idx == index)
+ return i;
+ }
+ return None;
+ }
+
+ // Helper to construct an affine map with one index removed.
+ static AffineMap adjustMap(AffineMap map, int64_t index,
+ PatternRewriter &rewriter) {
+ SmallVector<AffineExpr, 4> results;
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ if (idx == index)
+ continue;
+ // Re-insert remaining indices, but renamed when occurring
+ // after the removed index.
+ auto targetExpr =
+ getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext());
+ results.push_back(targetExpr);
+ }
+ // Since (...) -> () cannot be represented properly,
+ // we resort to an empty map when this situation happens.
+ return results.empty() ? AffineMap::get(rewriter.getContext())
+ : AffineMap::get(map.getNumDims() - 1, 0, results);
+ }
+
+ // Helper to drop the dimension at the given index from a vector type.
+ // A rank-1 vector type yields its scalar element type.
+ static Type adjustType(VectorType tp, int64_t index) {
+ int64_t rank = tp.getRank();
+ Type eltType = tp.getElementType();
+ if (rank == 1) {
+ assert(index == 0 && "index for scalar result out of bounds");
+ return eltType;
+ }
+ SmallVector<int64_t, 4> adjustedShape;
+ for (int64_t i = 0; i < rank; ++i) {
+ // Omit dimension at the given index.
+ if (i == index)
+ continue;
+ // Otherwise, add dimension back.
+ adjustedShape.push_back(tp.getDimSize(i));
+ }
+ return VectorType::get(adjustedShape, eltType);
+ }
+
+ // Helper method to load the slice of `val` at position `pos` along
+ // dimension `index`, dropping that dimension (no-op when index is -1).
+ // TODO(ajcbik): use a reshaping vector load (and share lowering code)
+ static Value reshapeLoad(Location loc, Value val, VectorType type,
+ int64_t index, int64_t pos,
+ PatternRewriter &rewriter) {
+ if (index == -1)
+ return val;
+ Type lowType = adjustType(type, 0);
+ // At extraction dimension?
+ if (index == 0) {
+ auto posAttr = rewriter.getI64ArrayAttr(pos);
+ return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
+ }
+ // Unroll leading dimensions.
+ VectorType vType = lowType.cast<VectorType>();
+ VectorType resType = adjustType(type, index).cast<VectorType>();
+ Value result = zeroVector(loc, resType, rewriter);
+ for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
+ auto posAttr = rewriter.getI64ArrayAttr(d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+ Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
+ result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
+ posAttr);
+ }
+ return result;
+ }
+
+ // Helper method to store `val` as the slice of `result` at position `pos`
+ // along dimension `index`; returns `val` itself when index is -1.
+ // TODO(ajcbik): use a reshaping vector store (and share lowering code)
+ static Value reshapeStore(Location loc, Value val, Value result,
+ VectorType type, int64_t index, int64_t pos,
+ PatternRewriter &rewriter) {
+ // Unmodified?
+ if (index == -1)
+ return val;
+ // At insertion dimension?
+ if (index == 0) {
+ auto posAttr = rewriter.getI64ArrayAttr(pos);
+ return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
+ }
+ // Unroll leading dimensions.
+ Type lowType = adjustType(type, 0);
+ VectorType vType = lowType.cast<VectorType>();
+ Type insType = adjustType(vType, 0);
+ for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
+ auto posAttr = rewriter.getI64ArrayAttr(d);
+ Value ext =
+ rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
+ Value ins =
+ rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+ Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+ result =
+ rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+ }
+ return result;
+ }
};
} // namespace
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 6c4cb5f4bfd0..f781e37d586b 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -14,7 +14,7 @@
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
// CHECK: return %[[R]] : f32
@@ -24,3 +24,148 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
: vector<4xf32>, vector<4xf32> into f32
return %0 : f32
}
+
+#matvec_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j)>,
+ affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract2
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
+// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[T9]] : vector<2xf32>
+
+func @extract_contract2(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+#vecmat_accesses = [
+ affine_map<(i, j) -> (j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i)>
+]
+#vecmat_trait = {
+ indexing_maps = #vecmat_accesses,
+ iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract3
+// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
+// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[T9]] : vector<2xf32>
+
+func @extract_contract3(%arg0: vector<3xf32>,
+ %arg1: vector<2x3xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
+ : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract4
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
+// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
+// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
+// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
+// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
+// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
+// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
+// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
+// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
+// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
+// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
+// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK: return %[[T45]] : vector<2x2xf32>
+
+func @extract_contract4(%arg0: vector<2x2xf32>,
+ %arg1: vector<2x2xf32>,
+ %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
More information about the Mlir-commits mailing list