[Mlir-commits] [mlir] b956f04 - [mlir][linalg] Vectorize directly to a named contraction (#147296)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 22:42:05 PDT 2025
Author: Adam Siemieniuk
Date: 2025-07-22T07:42:02+02:00
New Revision: b956f049b186fafafebc88b861982644ec3f5291
URL: https://github.com/llvm/llvm-project/commit/b956f049b186fafafebc88b861982644ec3f5291
DIFF: https://github.com/llvm/llvm-project/commit/b956f049b186fafafebc88b861982644ec3f5291.diff
LOG: [mlir][linalg] Vectorize directly to a named contraction (#147296)
Extends linalg vectorizer with a path to lower contraction ops directly
into `vector.contract`.
The direct rewriting preserves high-level op semantics and provides more
progressive lowering compared to reconstructing a contraction back from a
multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their
well-defined semantics to avoid complex precondition verification.
The new path is optional and disabled by default to avoid changing the
default vectorizer behavior.
Added:
mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bafeca924e4c5..8d45c40a93e2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,6 +2435,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+ OptionalAttr<UnitAttr>:$create_named_contraction,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9e62d0dcc7890..38e53648e7c34 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,12 +876,15 @@ struct VectorizationResult {
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
- bool assumeDynamicDimsMatchVecSizes = false);
+ bool assumeDynamicDimsMatchVecSizes = false,
+ bool createNamedContraction = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index d68138acec0db..7cd70e42d363c 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -227,7 +227,8 @@ bool isLinearizableVector(VectorType type);
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
- bool useInBoundsInsteadOfMasking = false);
+ bool useInBoundsInsteadOfMasking = false,
+ ArrayRef<bool> scalableDims = {});
/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c959310136319..109e5b7f95ec0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,8 +3920,10 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false), false,
- getAssumeDynamicDimsMatchVecSizes().value_or(false));
+ getVectorizeNdExtract().value_or(false),
+ /*flatten1DDepthwiseConv=*/false,
+ getAssumeDynamicDimsMatchVecSizes().value_or(false),
+ getCreateNamedContraction().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4add50f4b36e5..77c85abab9aa0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return write;
// Compute the mask and mask the write Op.
- auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+ vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
+ isa<MemRefType>(dest.getType())
+ ? memref::getMixedSizes(builder, loc, dest)
+ : tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());
@@ -2118,6 +2122,92 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+/// Vectorize a named linalg contraction op into:
+/// vector::TransferReadOp - Reads vectors from the operands
+/// vector::ContractionOp - Performs contraction
+/// vector::TransferWriteOp - Writes the result vector back to the
+/// destination
+/// The operand shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction op.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp,
+ SmallVectorImpl<Value> &newResults) {
+ Location loc = linalgOp.getLoc();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ // For simplicity, contraction vectorization is limited to linalg named ops.
+ // Generic ops are ignored as not every arbitrary contraction body can be
+ // expressed by a vector.contract.
+ if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+ return failure();
+
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+ Operation *reduceOp = matchLinalgReduction(outOperand);
+ auto maybeKind = getCombinerOpKind(reduceOp);
+ if (!maybeKind) {
+ LDBG("Failed to determine contraction combining kind.\n");
+ return failure();
+ }
+
+ // Check that all dimensions are present in the input operands.
+ // Arbitrary broadcasts are not supported by the vector contraction.
+ // Broadcasts are expected to be decomposed before vectorization.
+ AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+ AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+ if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+ LDBG("Contractions with broadcasts are not supported.\n");
+ return failure();
+ }
+
+ // Load every operand; the reads are used below as lhs, rhs and acc in order.
+ SmallVector<Value> vecOperands;
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ // The operand vector shape is computed by mapping the canonical vector
+ // shape to the operand's domain. Further permutations are left as a part of
+ // the contraction.
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+ AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+ indexingMap.getNumResults(), rewriter.getContext());
+ Type elemType = getElementTypeOrSelf(opOperand.get());
+ VectorType readType =
+ state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+ Value read = mlir::vector::createReadOrMaskedRead(
+ rewriter, loc, opOperand.get(), readType.getShape(),
+ /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+ /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ vecOperands.push_back(read);
+ }
+
+ // Remap iterators from linalg to vector; non-parallel ones become reductions.
+ SmallVector<Attribute> iterAttrs;
+ auto iterators = linalgOp.getIteratorTypesArray();
+ for (utils::IteratorType iter : iterators) {
+ auto vecIter = iter == utils::IteratorType::parallel
+ ? vector::IteratorType::parallel
+ : vector::IteratorType::reduction;
+ iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+ }
+
+ // Create the contraction, reusing the linalg op's indexing maps unchanged.
+ Operation *contractOp = rewriter.create<vector::ContractionOp>(
+ loc, /*lhs=*/vecOperands[0],
+ /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+ linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+ contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
+
+ // Store the contraction result into the init operand.
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, contractOp->getResult(0), outOperand->get());
+
+ // Finalize: forward the write's result, if it produces one.
+ if (!write->getResults().empty())
+ newResults.push_back(write->getResult(0));
+
+ return success();
+}
+
namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace
@@ -2557,7 +2647,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
FailureOr<VectorizationResult> mlir::linalg::vectorize(
RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
- bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+ bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+ bool createNamedContraction) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2695,11 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
return failure();
}
+ if (createNamedContraction &&
+ isa<ContractionOpInterface>(linalgOp.getOperation()))
+ return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+ results);
+
LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e4984582b373..9b055853fc8b0 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
- bool useInBoundsInsteadOfMasking) {
+ bool useInBoundsInsteadOfMasking,
+ ArrayRef<bool> scalableDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+ auto vectorType =
+ VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(builder, loc, source);
+ isa<MemRefType>(source.getType())
+ ? memref::getMixedSizes(builder, loc, source)
+ : tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+ auto maskType =
+ VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
new file mode 100644
index 0000000000000..d8f897cca958d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir
@@ -0,0 +1,484 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// Tests for vectorizing operations implementing contraction op interface.
+/// Ops implementing the contraction interface are vectorized directly to their
+/// vector dialect named counterparts.
+///----------------------------------------------------------------------------------------
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+
+/// Get the contraction dimensions
+// CHECK: %[[MATMUL_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[MATMUL_DIM_M:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MATMUL_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[MATMUL_DIM_N:.*]] = tensor.dim %[[B]], %[[MATMUL_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MATMUL_DIM_K_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[MATMUL_DIM_K:.*]] = tensor.dim %[[A]], %[[MATMUL_DIM_K_IDX]] : tensor<?x?xf32>
+
+/// Create a mask for the A matrix
+// CHECK: %[[A_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM_M:.*]] = tensor.dim %[[A]], %[[A_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[A_DIM_K_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[A_DIM_K:.*]] = tensor.dim %[[A]], %[[A_DIM_K_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_A_MASK:.*]] = vector.create_mask
+// CHECK-SAME: %[[A_DIM_M]], %[[A_DIM_K]] : vector<8x4xi1>
+/// Read the A matrix
+// CHECK: %[[LOAD_A:.*]] = vector.mask %[[LOAD_A_MASK]]
+// CHECK-SAME: { vector.transfer_read %[[A]]{{\[}}%[[A_OFFSET]], %[[A_OFFSET]]{{\]}}
+// CHECK-SAME: : tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME: : vector<8x4xi1> -> vector<8x4xf32>
+
+/// Create a mask for the B matrix
+// CHECK: %[[B_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[B_DIM_K_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[B_DIM_K:.*]] = tensor.dim %[[B]], %[[B_DIM_K_IDX]] : tensor<?x?xf32>
+// CHECK: %[[B_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[B_DIM_N:.*]] = tensor.dim %[[B]], %[[B_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_B_MASK:.*]] = vector.create_mask
+// CHECK-SAME: %[[B_DIM_K]], %[[B_DIM_N]] : vector<4x16xi1>
+/// Read the B matrix
+// CHECK: %[[LOAD_B:.*]] = vector.mask %[[LOAD_B_MASK]]
+// CHECK-SAME: { vector.transfer_read %[[B]]{{\[}}%[[B_OFFSET]], %[[B_OFFSET]]{{\]}}
+// CHECK-SAME: : tensor<?x?xf32>, vector<4x16xf32> }
+// CHECK-SAME: : vector<4x16xi1> -> vector<4x16xf32>
+
+/// Create a mask for the C matrix
+// CHECK: %[[C_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[C_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[C_DIM_M:.*]] = tensor.dim %[[C]], %[[C_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[C_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[C_DIM_N:.*]] = tensor.dim %[[C]], %[[C_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_C_MASK:.*]] = vector.create_mask
+// CHECK-SAME: %[[C_DIM_M]], %[[C_DIM_N]] : vector<8x16xi1>
+/// Read the C matrix
+// CHECK: %[[LOAD_C:.*]] = vector.mask %[[LOAD_C_MASK]]
+// CHECK-SAME: { vector.transfer_read %[[C]]{{\[}}%[[C_OFFSET]], %[[C_OFFSET]]{{\]}}
+// CHECK-SAME: : tensor<?x?xf32>, vector<8x16xf32> }
+// CHECK-SAME: : vector<8x16xi1> -> vector<8x16xf32>
+
+/// Create a mask for the contraction
+// CHECK: %[[CONTRACTION_MASK:.*]] = vector.create_mask
+// CHECK-SAME: %[[MATMUL_DIM_M]], %[[MATMUL_DIM_N]], %[[MATMUL_DIM_K]]
+// CHECK-SAME: : vector<8x16x4xi1>
+/// Perform the contraction
+// CHECK: %[[D:.*]] = vector.mask %[[CONTRACTION_MASK]]
+// CHECK-SAME: { vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK-SAME: } : vector<8x16x4xi1> -> vector<8x16xf32>
+
+/// Create a mask for the result
+// CHECK: %[[D_OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[D_DIM_M_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[D_DIM_M:.*]] = tensor.dim %[[C]], %[[D_DIM_M_IDX]] : tensor<?x?xf32>
+// CHECK: %[[D_DIM_N_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[D_DIM_N:.*]] = tensor.dim %[[C]], %[[D_DIM_N_IDX]] : tensor<?x?xf32>
+// CHECK: %[[LOAD_D_MASK:.*]] = vector.create_mask
+// CHECK-SAME: %[[D_DIM_M]], %[[D_DIM_N]] : vector<8x16xi1>
+/// Write the result
+// CHECK: vector.mask %[[LOAD_D_MASK]]
+// CHECK-SAME: { vector.transfer_write %[[D]], %[[C]]{{\[}}%[[D_OFFSET]], %[[D_OFFSET]]{{\]}}
+// CHECK-SAME: : vector<8x16xf32>, tensor<?x?xf32> }
+// CHECK-SAME: : vector<8x16xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+ %C: memref<?x?xf32>) {
+ linalg.matmul
+ ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_scalable(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32> }
+// CHECK-SAME: : vector<8x4xi1> -> vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32> }
+// CHECK-SAME: : vector<4x[16]xi1> -> vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32> }
+// CHECK-SAME: : vector<8x[16]xi1> -> vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK-SAME: } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32> }
+// CHECK-SAME: : vector<8x[16]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+ affine_map<(m, n, k) -> (n, k)>, // transpose B
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<4x8xf32>, tensor<16x4xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose(
+// CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>, %[[B:.*]]: tensor<16x4xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<16x4xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+ affine_map<(m, n, k) -> (n, k)>, // transpose B
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_transpose(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+/// Contractions with arbitrary broadcasts are not supported in contraction interface
+/// vectorization.
+/// Dimension broadcasts are expected to be decomposed first which removes ambiguity
+/// caused by possible variants of dimension materialization.
+/// For example, whether the below target LHS input layout is (m, k) or (k, m).
+
+func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<4xf32>, tensor<4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<8x4xf16>, tensor<4x16xf16>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_mixed_precision(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf16>, %[[B:.*]]: tensor<4x16xf16>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf16>, vector<8x4xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf16>, vector<4x16xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
+ %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
+ %0 = linalg.batch_matmul // Named op under test; the leading dim of size 3 is present on both inputs and on the output.
+ ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
+ outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
+ return %0 : tensor<3x8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<3x8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op // Pick the linalg.batch_matmul ops out of the payload.
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op // Vectorize with the opt-in direct-to-vector.contract path enabled.
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf32>, %B: tensor<3x4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.batch_reduce_matmul // Named op under test; the inputs carry a leading dim of size 3 that the 8x16 output does not have.
+ ins(%A, %B : tensor<3x8x4xf32>, tensor<3x4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<3x8x4xf32>, %[[B:.*]]: tensor<3x4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf32>, vector<3x8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf32>, vector<3x4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op // Pick the linalg.batch_reduce_matmul ops out of the payload.
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op // Vectorize with the opt-in direct-to-vector.contract path enabled.
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @contract(%A: tensor<4x8x2xf32>, %B: tensor<8x16x2xf32>,
+ %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
+ %0 = linalg.contract // Op under test with user-provided maps; k and kk index both inputs but are absent from the output map.
+ indexing_maps = [affine_map<(m, n, k, kk) -> (m, k, kk)>,
+ affine_map<(m, n, k, kk) -> (k, n, kk)>,
+ affine_map<(m, n, k, kk) -> (m, n)>]
+ ins(%A, %B : tensor<4x8x2xf32>, tensor<8x16x2xf32>)
+ outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
+ return %0 : tensor<4x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract(
+// CHECK-SAME: %[[A:.*]]: tensor<4x8x2xf32>, %[[B:.*]]: tensor<8x16x2xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<4x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf32>, vector<4x8x2xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf32>, vector<8x16x2xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op // Pick the linalg.contract ops out of the payload.
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op // Vectorize with the opt-in direct-to-vector.contract path enabled.
+ transform.yield
+ }
+}
+
+// -----
+
+/// A linalg.generic can represent a contraction, but it does not implement the contraction interface.
+/// Thus, direct lowering to vector.contract is not supported for it.
+/// Vectorization still succeeds by falling back to the generic rewrite logic.
+
+func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.generic { // Matmul-shaped computation spelled as a generic op (mulf + addf body) rather than a named op.
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+ outs(%C : tensor<8x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_generic(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op // Pick the linalg.generic ops out of the payload.
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op // Flag is set, yet the CHECK lines expect vector.multi_reduction with no vector.contract before it.
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list