[Mlir-commits] [mlir] f245b7a - [mlir][Linalg] Generalize the definition of a Linalg contraction.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Feb 3 23:55:09 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-04T07:50:44Z
New Revision: f245b7ad36ff8bd85cddbe9784f7efe6dee577c0
URL: https://github.com/llvm/llvm-project/commit/f245b7ad36ff8bd85cddbe9784f7efe6dee577c0
DIFF: https://github.com/llvm/llvm-project/commit/f245b7ad36ff8bd85cddbe9784f7efe6dee577c0.diff
LOG: [mlir][Linalg] Generalize the definition of a Linalg contraction.
This revision defines a Linalg contraction in general terms:
1. It has 2 input shapes and 1 output shape.
2. It has at least one reduction dimension.
3. It has only projected permutation indexing maps.
4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
(AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
operations that may change the type (e.g. for mixed-precision); a sketch of
such a body is shown below.
As a consequence, when vectorization of such an op occurs, the only special
behavior is that the (unique) MulOpType is vectorized into a
`vector.contract`. All other ops are handled in a generic fashion.
In the future, we may wish to allow more input arguments, as well as
elementwise and constant operations that do not involve the reduction
dimension(s).
A test is added to demonstrate the proper vectorization of matmul_i8_i8_i32.
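For illustration, a region body of the following shape satisfies this
definition in the mixed-precision case (a hand-written sketch mirroring the
matmul_i8_i8_i32 test below; the actual named-op body is generated from
TableGen and may order operands differently). The product is computed in i8
and then sign-extended, so `sexti` plays the role of u2 while the other
unary slots are identities:

  ^bb0(%a: i8, %b: i8, %c: i32):
    %0 = muli %a, %b : i8        // the (unique) MulOpType
    %1 = sexti %0 : i8 to i32    // u2: unary op that changes the type
    %2 = addi %c, %1 : i32       // the AddOpType
    linalg.yield %2 : i32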
Differential revision: https://reviews.llvm.org/D95939
Added:
Modified:
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 7791ed0d5eee..5e577d778210 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -77,7 +77,7 @@ func @main() {
scf.for %arg0 = %c0 to %iters step %c1 {
// linalg.matmul writes %C in place, need to reset it to zero every time.
// This accounts for about a 10-15% perf hit on small sizes.
- // Once linalg on tensors is ready, fusing fill at teh register level will
+ // Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
%z = constant 0.0 : !elem_type_c
linalg.fill(%C, %z) : !row_major_C, !elem_type_c
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
index e454c7cb8160..de4e51bd8c0e 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
@@ -75,7 +75,7 @@ func @main() {
scf.for %arg0 = %c0 to %iters step %c1 {
// linalg.matmul writes %C in place, need to reset it to zero every time.
// This accounts for about a 10-15% perf hit on small sizes.
- // Once linalg on tensors is ready, fusing fill at teh register level will
+ // Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
index 287cb1c24059..95fc57506c43 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
@@ -84,7 +84,7 @@ func @main() {
scf.for %arg0 = %c0 to %iters step %c1 {
// linalg.matmul writes %C in place, need to reset it to zero every time.
// This accounts for about a 10-15% perf hit on small sizes.
- // Once linalg on tensors is ready, fusing fill at teh register level will
+ // Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
linalg.fill(%C, %f0) : !row_major_C, !elem_type_c
call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
index 961a83fd3f57..abfb14739e25 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
@@ -1,12 +1,11 @@
// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
-// TODO: extend vectorization with interfaces so that it works with sexti
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16 vectorize" | \
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
-// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -mlir-disable-threading | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
// Activate to dump assembly
// R_UN: -dump-object-file -object-filename=/tmp/a.o \
@@ -18,9 +17,9 @@
!elem_type_a = type i8
!elem_type_b = type i8
!elem_type_c = type i32
-!row_major_A = type memref<${M}x${K}x!elem_type_a>
-!row_major_B = type memref<${K}x${N}x!elem_type_b>
-!row_major_C = type memref<${M}x${N}x!elem_type_c>
+!row_major_A = type memref<24x64x!elem_type_a>
+!row_major_B = type memref<64x192x!elem_type_b>
+!row_major_C = type memref<24x192x!elem_type_c>
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
// TODO: activate manually for now.
@@ -33,9 +32,9 @@ func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
func @print_perf(%iters: index, %total_time: f64) {
%c2 = constant 2 : index
- %cM = constant ${M} : index
- %cN = constant ${N} : index
- %cK = constant ${K} : index
+ %cM = constant 24 : index
+ %cN = constant 192 : index
+ %cK = constant 64 : index
%mn = muli %cM, %cN : index
%mnk = muli %mn, %cK : index
@@ -65,7 +64,7 @@ func @main() {
%c0 = constant 0: index
%c1 = constant 1: index
- %iters = constant ${ITERS}: index
+ %iters = constant 100: index
/// Run and dump performance for matmul.
/// Preheating run:
@@ -77,7 +76,7 @@ func @main() {
scf.for %arg0 = %c0 to %iters step %c1 {
// linalg.matmul writes %C in place, need to reset it to zero every time.
// This accounts for about a 10-15% perf hit on small sizes.
- // Once linalg on tensors is ready, fusing fill at teh register level will
+ // Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index fc1266e3608a..fb9d452cfcef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -38,6 +38,70 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
+/// Return true if the use-def chain from `v` to `from` consists of 0 or more
+/// unary single-operand operations.
+// TODO: relax to multi-operands with constants, which are technically unary ops
+// as needed (e.g. add5).
+static bool isChainOfUnaryOpsFrom(Value v, Value from) {
+ while (v != from) {
+ Operation *op = v.getDefiningOp();
+ if (!op || op->getNumOperands() != 1)
+ return false;
+ v = op->getOperand(0);
+  }
+ return true;
+}
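For intuition, in the hypothetical snippet below, isChainOfUnaryOpsFrom(%1,
%arg0) returns true, because every defining op on the use-def path from %1
back to %arg0 has exactly one operand; it would return false if any op on
the path took two operands (e.g. an addi):

  %0 = sexti %arg0 : i8 to i32
  %1 = trunci %0 : i32 to i16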
+
+/// Return the unique instance of OpType in `block` if it is indeed unique.
+/// Return null if none or more than one instance exists.
+template <typename OpType>
+static OpType getSingleOpOfType(Block &block) {
+ OpType res;
+ block.walk([&](OpType op) {
+ if (res) {
+ res = nullptr;
+ return WalkResult::interrupt();
+ }
+ res = op;
+ return WalkResult::advance();
+ });
+ return res;
+}
+
+/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
+/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
+/// unary operations that may change the type.
+template <typename AddOpType, typename MulOpType>
+static bool isAddMul(Block &block) {
+ if (block.getNumArguments() != 3)
+ return false;
+ Operation *yieldOp = block.getTerminator();
+ if (yieldOp->getNumOperands() != 1)
+ return false;
+
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isAddMul: "; block.dump());
+ AddOpType addOp = getSingleOpOfType<AddOpType>(block);
+ MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
+ if (!addOp || !mulOp)
+ return false;
+
+ Value argA = block.getArgument(0), argB = block.getArgument(1);
+ Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+ Value mul = mulOp->getResult(0);
+ Value argC = block.getArgument(2);
+ Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
+ Value add = addOp->getResult(0);
+ Value res = yieldOp->getOperand(0);
+  // The result must trace back to the add through a chain of unary ops.
+  auto un = isChainOfUnaryOpsFrom;
+  bool success = un(res, add);
+  // One operand of the add must trace back to argC, the other to the mul.
+  success &= (un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC));
+  // One operand of the mul must trace back to argA, the other to argB.
+  success &= (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
+  return success;
+}
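As an example of the definition's generality, the following hypothetical f16
body exercises every unary slot of `u5(u1(c) + u2(u3(a) * u4(b)))` and would
be matched by isAddMul<AddFOp, MulFOp> (a sketch, not the body of any
existing named op):

  ^bb0(%a: f16, %b: f16, %c: f16):
    %0 = fpext %a : f16 to f32      // u3
    %1 = fpext %b : f16 to f32      // u4
    %2 = mulf %0, %1 : f32          // the MulOpType
    %3 = fpext %c : f16 to f32      // u1
    %4 = addf %3, %2 : f32          // the AddOpType
    %5 = fptrunc %4 : f32 to f16    // u5
    linalg.yield %5 : f16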
+
/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate.
enum VectorizationStatus {
@@ -146,7 +210,7 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
results.push_back(result);
}
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
-};
+}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
@@ -305,55 +369,34 @@ static LogicalResult vectorizeAsLinalgGeneric(
return success();
}
-/// Detect whether `r` exactly computes a floating-point or integer
-/// multiply-accumulate.
-static bool hasMultiplyAddBody(Region &r) {
- if (!llvm::hasSingleElement(r))
- return false;
- if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
- return false;
-
- using mlir::matchers::m_Val;
- auto a = m_Val(r.getArgument(0));
- auto b = m_Val(r.getArgument(1));
- auto c = m_Val(r.getArgument(2));
- // TODO: Update this detection once we have matcher support for specifying
- // that any permutation of operands matches.
- auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
- auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
- auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
- auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
- auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
- auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
- auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
- auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
- return pattern1.match(&r.front().back()) ||
- pattern2.match(&r.front().back()) ||
- pattern3.match(&r.front().back()) ||
- pattern4.match(&r.front().back()) ||
- pattern5.match(&r.front().back()) ||
- pattern6.match(&r.front().back()) ||
- pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
-}
-
/// Detect whether the LinalgOp `op` is a contraction.
-// TODO: Should be Tablegen'd from a single source that generates the op itself.
+/// A Linalg contraction is defined in general terms:
+/// 1. It has 2 input shapes and 1 output shape.
+/// 2. It has at least one reduction dimension.
+/// 3. It has only projected permutation indexing maps.
+/// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
+/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
+/// operations that may change the type (e.g. for mixed-precision).
+/// As a consequence, when vectorization of such an op occurs, the only special
+/// behavior is that the (unique) MulOpType is vectorized into a
+/// `vector.contract`. All other ops are handled in a generic fashion.
+/// In the future, we may wish to allow more input arguments, as well as
+/// elementwise and constant operations that do not involve the reduction
+/// dimension(s).
static LogicalResult isContraction(Operation *op) {
- // TODO: interface for named ops.
- if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatmulColumnMajorOp,
- linalg::MatvecOp, linalg::VecmatOp, linalg::DotOp>(op))
- return success();
-
- auto genericOp = dyn_cast<linalg::GenericOp>(op);
- if (!genericOp)
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isContraction: "; op->dump());
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ if (!linalgOp)
return failure();
- auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
+ auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
return success(
- genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+ linalgOp.getNumInputs() == 2 && linalgOp.getNumOutputs() == 1 &&
+ linalgOp.getNumReductionLoops() > 0 &&
llvm::all_of(mapRange,
[](AffineMap m) { return m.isProjectedPermutation(); }) &&
- hasMultiplyAddBody(genericOp.region()));
+ // TODO: more fields than add/mul.
+ (isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) ||
+ isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())));
}
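For reference, condition 3 admits indexing maps such as the following
illustrative examples (a projected permutation composes a projection with a
permutation of the loop dimensions, so every result must be a plain,
distinct dimension):

  affine_map<(d0, d1, d2) -> (d0, d2)>   // projected permutation (matmul LHS)
  affine_map<(d0, d1, d2) -> (d2, d1)>   // projected permutation (matmul RHS)
  affine_map<(d0, d1) -> (d0 + d1)>      // rejected: arithmetic on dimensions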
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -382,7 +425,7 @@ static bool isElementwise(Operation *op) {
if (!genericOp.getOutputIndexingMap(i).isIdentity())
return false;
}
- // Currently limit the input indexing map to minor identity as other
+  // Currently restrict the input indexing map to minor identity, as other
// permutations might require adding transpose ops to convert the vector read
// to the right shape.
for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
@@ -479,6 +522,150 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
"Unexpected vectorization failed despite preconditions");
}
+//----------------------------------------------------------------------------//
+// Misc. conv vectorization patterns.
+//----------------------------------------------------------------------------//
+// TODO: cleanup all this.
+template <class ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+ ConvOp op, PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ MLIRContext *context = op.getContext();
+ edsc::ScopedContext scope(rewriter, loc);
+
+ ShapedType inShapeType = op.getInputShapedType(0);
+ ShapedType kShapeType = op.getInputShapedType(1);
+
+ ArrayRef<int64_t> inShape = inShapeType.getShape();
+ ArrayRef<int64_t> kShape = kShapeType.getShape();
+
+ if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+ return failure();
+
+ SmallVector<AffineExpr, 4> mapping;
+ SmallVector<int64_t, 4> vectorDims;
+  // Fail to apply when the size of a non-vectorized dimension is not 1.
+ for (unsigned i = 0; i < N; i++) {
+ if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
+ return failure();
+
+ if (mask[i] && inShape[i] != kShape[i])
+ return failure();
+
+ if (mask[i]) {
+ mapping.push_back(getAffineDimExpr(i, context));
+ vectorDims.push_back(inShape[i]);
+ }
+ }
+
+ Value input = op.getInput(0);
+ Value kernel = op.getInput(1);
+ Value output = op.getOutputBuffer(0);
+
+ unsigned rank = inShapeType.getRank();
+ unsigned numDims = mapping.size();
+ Type elemType = inShapeType.getElementType();
+
+ auto map = AffineMap::get(rank, 0, mapping, context);
+ SmallVector<Value, 4> zeros(rank, std_constant_index(0));
+ auto vecType = VectorType::get(vectorDims, elemType);
+
+ auto inputVec = vector_transfer_read(vecType, input, zeros, map);
+ auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
+
+ auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
+
+ std::array<AffineMap, 3> indexingMaps{
+ AffineMap::getMultiDimIdentityMap(numDims, context),
+ AffineMap::getMultiDimIdentityMap(numDims, context),
+ AffineMap::get(numDims, 0, {}, context)};
+
+ std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+ auto result = rewriter.create<vector::ContractionOp>(
+ loc, inputVec, kernelVec, acc,
+ rewriter.getAffineMapArrayAttr(indexingMaps),
+ rewriter.getStrArrayAttr(iteratorTypes));
+
+ rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
+ rewriter.eraseOp(op);
+ return success();
+}
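Roughly, for a 1-d convolution whose single dimension is vectorized
(mask = {true}) with static size 3, this pattern emits IR of the following
shape (a hand-written sketch under those assumptions, with hypothetical
value names; not the literal output of any test in this patch):

  %acc = constant 0.000000e+00 : f32
  %v0 = vector.transfer_read %input[%c0], %acc : memref<3xf32>, vector<3xf32>
  %v1 = vector.transfer_read %kernel[%c0], %acc : memref<3xf32>, vector<3xf32>
  %r = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
                                         affine_map<(d0) -> (d0)>,
                                         affine_map<(d0) -> ()>],
                        iterator_types = ["reduction"]}
       %v0, %v1, %acc : vector<3xf32>, vector<3xf32> into f32
  store %r, %output[%c0] : memref<1xf32>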
+
+using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
+
+/// Inserts tiling, promotion and vectorization pattern for ConvOp
+/// conversion into corresponding pattern lists.
+template <typename ConvOp, unsigned N>
+static void
+populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
+ OwningRewritePatternList &promotionPatterns,
+ OwningRewritePatternList &vectorizationPatterns,
+ ArrayRef<int64_t> tileSizes,
+ MLIRContext *context) {
+ if (tileSizes.size() < N)
+ return;
+
+ constexpr static StringRef kTiledMarker = "TILED";
+ constexpr static StringRef kPromotedMarker = "PROMOTED";
+ tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
+ context, LinalgTilingOptions().setTileSizes(tileSizes),
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get(kTiledMarker, context)));
+
+ promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
+ context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+ LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
+ Identifier::get(kPromotedMarker, context)));
+
+ SmallVector<bool, 4> mask(N);
+ int offset = tileSizes.size() - N;
+ std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
+ [](int64_t i) -> bool { return i > 1; });
+
+ vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
+}
+
+void mlir::linalg::populateConvVectorizationPatterns(
+ MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+ ArrayRef<int64_t> tileSizes) {
+ OwningRewritePatternList tiling, promotion, vectorization;
+ populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
+ tileSizes, context);
+
+ populateVectorizationPatterns<ConvNDHWCOp, 5>(
+ tiling, promotion, vectorization, tileSizes, context);
+
+ populateVectorizationPatterns<ConvNCDHWOp, 5>(
+ tiling, promotion, vectorization, tileSizes, context);
+
+ patterns.push_back(std::move(tiling));
+ patterns.push_back(std::move(promotion));
+ patterns.push_back(std::move(vectorization));
+}
+
+//----------------------------------------------------------------------------//
+// Forwarding patterns
+//----------------------------------------------------------------------------//
+
/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
@@ -649,139 +836,3 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
return success();
}
-
-template <class ConvOp, int N>
-LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
- ConvOp op, PatternRewriter &rewriter) const {
- Location loc = op.getLoc();
- MLIRContext *context = op.getContext();
- edsc::ScopedContext scope(rewriter, loc);
-
- ShapedType inShapeType = op.getInputShapedType(0);
- ShapedType kShapeType = op.getInputShapedType(1);
-
- ArrayRef<int64_t> inShape = inShapeType.getShape();
- ArrayRef<int64_t> kShape = kShapeType.getShape();
-
- if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
- return failure();
-
- SmallVector<AffineExpr, 4> mapping;
- SmallVector<int64_t, 4> vectorDims;
- // Fail to apply when the size of not vectorized dimension is not 1.
- for (unsigned i = 0; i < N; i++) {
- if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
- return failure();
-
- if (mask[i] && inShape[i] != kShape[i])
- return failure();
-
- if (mask[i]) {
- mapping.push_back(getAffineDimExpr(i, context));
- vectorDims.push_back(inShape[i]);
- }
- }
-
- Value input = op.getInput(0);
- Value kernel = op.getInput(1);
- Value output = op.getOutputBuffer(0);
-
- unsigned rank = inShapeType.getRank();
- unsigned numDims = mapping.size();
- Type elemType = inShapeType.getElementType();
-
- auto map = AffineMap::get(rank, 0, mapping, context);
- SmallVector<Value, 4> zeros(rank, std_constant_index(0));
- auto vecType = VectorType::get(vectorDims, elemType);
-
- auto inputVec = vector_transfer_read(vecType, input, zeros, map);
- auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
-
- auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
-
- std::array<AffineMap, 3> indexingMaps{
- AffineMap::getMultiDimIdentityMap(numDims, context),
- AffineMap::getMultiDimIdentityMap(numDims, context),
- AffineMap::get(numDims, 0, {}, context)};
-
- std::vector<StringRef> iteratorTypes(numDims, "reduction");
-
- auto result = rewriter.create<vector::ContractionOp>(
- loc, inputVec, kernelVec, acc,
- rewriter.getAffineMapArrayAttr(indexingMaps),
- rewriter.getStrArrayAttr(iteratorTypes));
-
- rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
- rewriter.eraseOp(op);
- return success();
-}
-
-using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
-
-/// Inserts tiling, promotion and vectorization pattern for ConvOp
-/// conversion into corresponding pattern lists.
-template <typename ConvOp, unsigned N>
-static void
-populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
- OwningRewritePatternList &promotionPatterns,
- OwningRewritePatternList &vectorizationPatterns,
- ArrayRef<int64_t> tileSizes,
- MLIRContext *context) {
- if (tileSizes.size() < N)
- return;
-
- constexpr static StringRef kTiledMarker = "TILED";
- constexpr static StringRef kPromotedMarker = "PROMOTED";
- tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
- context, LinalgTilingOptions().setTileSizes(tileSizes),
- LinalgTransformationFilter(ArrayRef<Identifier>{},
- Identifier::get(kTiledMarker, context)));
-
- promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
- context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
- Identifier::get(kPromotedMarker, context)));
-
- SmallVector<bool, 4> mask(N);
- int offset = tileSizes.size() - N;
- std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
- [](int64_t i) -> bool { return i > 1; });
-
- vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
-}
-
-void mlir::linalg::populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
- ArrayRef<int64_t> tileSizes) {
- OwningRewritePatternList tiling, promotion, vectorization;
- populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
-
- populateVectorizationPatterns<ConvNDHWCOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
-
- populateVectorizationPatterns<ConvNCDHWOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
-
- patterns.push_back(std::move(tiling));
- patterns.push_back(std::move(promotion));
- patterns.push_back(std::move(vectorization));
-}
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
index 64d3405f016d..21aba6cdaf7b 100644
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse | FileCheck %s
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file | FileCheck %s
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
@@ -6,16 +7,11 @@
// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)>
// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
- outs(%arg2 : memref<?xf32>)
- return
-}
-
// CHECK-LABEL: @conv_1d
// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32
+func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
// CHECK-DAG: %[[c12:.*]] = constant 12 : index
// CHECK-DAG: %[[c4:.*]] = constant 4 : index
// CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
@@ -50,3 +46,8 @@ func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>)
// CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
// CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref<?xf32>
// CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref<?xf32, #[[$map1]]>
+ linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+ outs(%arg2 : memref<?xf32>)
+ return
+}
+
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index aa249542a07d..3904353287c5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,8 +1,6 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// -----
// CHECK-LABEL: contraction_dot
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -13,6 +11,8 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
return
}
+// -----
+
// CHECK-LABEL: contraction_matvec
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// CHECK: vector.contract
@@ -22,6 +22,8 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
return
}
+// -----
+
// CHECK-LABEL: contraction_matmul
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: vector.contract
@@ -31,6 +33,8 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
return
}
+// -----
+
// CHECK-LABEL: contraction_batch_matmul
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: vector.contract
@@ -41,6 +45,8 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
return
}
+// -----
+
#matmul_trait = {
args_in = 2,
args_out = 1,
@@ -51,8 +57,20 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
],
iterator_types = ["parallel", "parallel", "reduction"]
}
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @vectorization_test
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+ // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]]
+ // CHECK-SAME: vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
outs(%C : memref<8x32xf32>) {
@@ -63,15 +81,33 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
}
return
}
-// CHECK-LABEL: func @vectorization_test
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
+// -----
+
+#matmul_trait = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @vectorization_test_integer
func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
%C: memref<8x32xi32>) {
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+ // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]],
+ // CHECK-SAME: vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
outs(%C : memref<8x32xi32>) {
@@ -82,58 +118,71 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
}
return
}
-// CHECK-LABEL: func @vectorization_test_integer
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
-// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+// -----
+
+// CHECK-LABEL: func @vectorization_test_2
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
+ // CHECK: vector.contract {{.*}} :
+  // CHECK-SAME: vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
linalg.matmul
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
outs(%C: memref<8x32xf32>)
return
}
-// CHECK-LABEL: func @vectorization_test_2
-// CHECK: vector.contract {{.*}} :
-// vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+ // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
linalg.fill(%A, %arg0) : memref<8x16xf32>, f32
return
}
-// CHECK-LABEL: func @test_vectorize_fill
-// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
-// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill_scalar
func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+ // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+ // CHECK: store %[[V]], %[[M]][] : memref<f32>
linalg.fill(%A, %arg0) : memref<f32>, f32
return
}
-// CHECK-LABEL: func @test_vectorize_fill
-// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
-// CHECK: store %[[V]], %[[M]][] : memref<f32>
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy
func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
linalg.copy(%A, %B) : memref<8x16xf32>, memref<8x16xf32>
return
}
-// CHECK-LABEL: func @test_vectorize_copy
-// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_scalar
func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+ // CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+ // CHECK: store %[[V]], {{.*}} : memref<f32>
linalg.copy(%A, %B) : memref<f32>, memref<f32>
return
}
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-// CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
-// CHECK: store %[[V]], {{.*}} : memref<f32>
-func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
- %arg2: memref<256xf32>, %i: f32) {
+// -----
+
+// CHECK-LABEL: func @generic_vectorize
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
+ // CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
+func @generic_vectorize(%arg0: memref<4x256xf32>,
+ %arg1: memref<4x256xf32>,
+ %arg2: memref<256xf32>, %i: f32) {
+ // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+ // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
%c1_f32 = constant 1.0 : f32
linalg.generic {
args_in = 0 : i64,
@@ -159,57 +208,56 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
memref<4x256xf32>, memref<4x256xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+ // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+ // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+ // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
%arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
%arg14 : f32):
+ // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+ // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
%6 = addf %arg4, %arg6 : f32
+ // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
%7 = cmpf ogt, %arg3, %arg6 : f32
+ // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
%8 = constant 2.0 : f32
+ // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
%9 = divf %arg5, %i : f32
+ // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
%10 = exp2 %arg5 : f32
+ // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
%11 = mulf %arg5, %8 : f32
+ // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
%12 = rsqrt %arg5 : f32
+ // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
%13 = select %7, %arg5, %arg6 : f32
+ // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+ // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
%14 = subf %arg5, %arg4 : f32
+ // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
%15 = tanh %arg5 : f32
+ // CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+ // CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
f32, f32, f32, f32, f32, f32, f32, f32
}
return
}
-// CHECK-LABEL: func @generic_vectorize
-// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
-// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
-// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
-// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
-// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
-// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
-// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
-// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
-// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
-// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
-// CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// -----
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
%arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
%i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
@@ -240,82 +288,105 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
%arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
%arg14 : f32):
+ // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+ // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+ // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+ // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+ // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+ // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
%6 = addf %arg4, %arg6 : f32
+ // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
%7 = cmpf ogt, %arg3, %arg6 : f32
+ // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
%8 = constant 2.0 : f32
+ // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
%9 = divf %arg5, %i : f32
+ // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
%10 = exp2 %arg5 : f32
+ // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
%11 = mulf %arg5, %8 : f32
+ // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
%12 = rsqrt %arg5 : f32
+ // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
%13 = select %7, %arg5, %arg6 : f32
+ // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+ // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
%14 = subf %arg5, %arg4 : f32
+ // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
%15 = tanh %arg5 : f32
+ // CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+ // CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
f32, f32, f32, f32, f32, f32, f32, f32
} -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+ // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
}
-// CHECK-LABEL: func @generic_vectorize_tensor
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
-// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
-// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
-// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
-// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
-// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
-// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
-// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
-// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
-// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
-// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+// -----
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
func @matmul_tensors(
%arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
-> tensor<8x12xf32> {
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
+ // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+ // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+ // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+ //
+  // The linalg contraction lowers to %tmp = vector.contract %a, %b, %c0, followed by addf %c, %tmp.
+  // A later canonicalization fuses the addf into the vector.contract.
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+ // CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
+ // CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
outs(%arg2: tensor<8x12xf32>)
-> tensor<8x12xf32>
+ // CHECK: return %[[W]] : tensor<8x12xf32>
return %0 : tensor<8x12xf32>
}
-// CHECK-LABEL: func @matmul_tensors
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
-// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
-// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
-//
-// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
-// a later canonicalization fuses the add into vector.contract.
-// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
-// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
-// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
-// CHECK: return %[[W]] : tensor<8x12xf32>
+// -----
+
+// CHECK-LABEL: func @matmul_i8_i8_i32
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<4x6xi8>
+// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<6x12xi8>
+// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
+func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+ // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
+ // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
+ // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+ //
+  // The linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 (in i8), followed by sexti %tmp to i32 and addi %c, %tmp.
+  // A later canonicalization fuses the add into the vector.contract.
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]]
+ // CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
+ // CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
+ // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+ // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
+ // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32>
+ linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)
+ outs(%c: memref<4x12xi32>)
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 126bbc3639af..27ca9942a74c 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -493,9 +493,11 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
OwningRewritePatternList patterns;
+ // TODO: remove all this in favor of a single LinalgOp.
patterns.insert<
LinalgVectorizationPattern<BatchMatmulOp>,
LinalgVectorizationPattern<MatmulOp>,
+ LinalgVectorizationPattern<MatmulI8I8I32Op>,
LinalgVectorizationPattern<MatvecOp>,
LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,