[Mlir-commits] [mlir] [mlir][linalg] Vectorize directly to a named contraction (PR #147296)
llvmlistbot at llvm.org
Mon Jul 7 06:29:49 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Adam Siemieniuk (adam-smnk)
Changes:
Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`.
The direct rewrite preserves the high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification.
The new path is optional and disabled by default, leaving the default vectorizer behavior unchanged.
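As a minimal sketch of the intended effect, mirroring the static `matmul` test added below (transfer indices, permutation maps, and masking elided for brevity):

```mlir
// Before: a named contraction.
%0 = linalg.matmul
  ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
  outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>

// After vectorization with the new `create_named_contraction` path:
//   %a = vector.transfer_read %A ... : tensor<8x4xf32>, vector<8x4xf32>
//   %b = vector.transfer_read %B ... : tensor<4x16xf32>, vector<4x16xf32>
//   %c = vector.transfer_read %C ... : tensor<8x16xf32>, vector<8x16xf32>
//   %r = vector.contract {indexing_maps = [...], kind = #vector.kind<add>}
//          %a, %b, %c : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
//   vector.transfer_write %r, %C ... : vector<8x16xf32>, tensor<8x16xf32>
```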
---
Patch is 32.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147296.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4-1)
- (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+2-1)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+103-7)
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+9-4)
- (added) mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir (+400)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..3bbde1240286c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+ OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+ OptionalAttr<UnitAttr>:$create_named_contraction,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
$scalable_sizes);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly into a `vector.contract` operation.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+ bool createNamedContraction = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
- bool useInBoundsInsteadOfMasking = false);
+ bool useInBoundsInsteadOfMasking = false,
+ ArrayRef<bool> scalableDims = {});
/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ec3f3445281a..1038615f21b17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false));
+ getVectorizeNdExtract().value_or(false),
+ getFlatten1DDepthwiseConv().value_or(false),
+ getCreateNamedContraction().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2864c5b807d69..bb1454cf43932 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return write;
// Compute the mask and mask the write Op.
- auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+ vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
+ isa<MemRefType>(dest.getType())
+ ? memref::getMixedSizes(builder, loc, dest)
+ : tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - reads vectors from the operands
+///   vector::ContractionOp - performs the contraction
+///   vector::TransferWriteOp - writes the result vector back to the
+///   destination
+/// The operands' shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the
+/// contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp,
+ SmallVectorImpl<Value> &newResults) {
+ Location loc = linalgOp.getLoc();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+ return failure();
+
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+ Operation *reduceOp = matchLinalgReduction(outOperand);
+ auto maybeKind = getCombinerOpKind(reduceOp);
+ if (!maybeKind)
+ return failure();
+
+ // Check that all dimensions are present in the input operands.
+ // Arbitrary broadcasts are not supported by the vector contraction.
+ // Broadcasts are expected to be materialized before vectorization.
+ AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+ AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+ if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+ return failure();
+
+ // Load operands.
+ SmallVector<Value> vecOperands;
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ // The operand vector shape is computed by mapping the canonical vector
+ // shape to the operand's domain. Further permutations are left as a part of
+ // the contraction.
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+ AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+ indexingMap.getNumResults(), rewriter.getContext());
+ Type elemType = getElementTypeOrSelf(opOperand.get());
+ VectorType readType =
+ state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+ Value read = mlir::vector::createReadOrMaskedRead(
+ rewriter, loc, opOperand.get(), readType.getShape(),
+ /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+ /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ vecOperands.push_back(read);
+ }
+
+ // Remap iterators from linalg to vector.
+ SmallVector<Attribute> iterAttrs;
+ auto iterators = linalgOp.getIteratorTypesArray();
+ for (utils::IteratorType iter : iterators) {
+ auto vecIter = iter == utils::IteratorType::parallel
+ ? vector::IteratorType::parallel
+ : vector::IteratorType::reduction;
+ iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+ }
+
+ // Create contraction.
+ Value contractOp = rewriter.create<vector::ContractionOp>(
+ loc, /*lhs=*/vecOperands[0],
+ /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+ linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+ // Store result.
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+ // Finalize.
+ if (!write->getResults().empty())
+ newResults.push_back(write->getResult(0));
+
+ return success();
+}
+
namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+ RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv, bool createNamedContraction) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
+  // For simplicity, contraction vectorization is limited to named linalg
+  // ops. Generic ops are ignored as not every arbitrary contraction body
+  // can be expressed by a vector.contract.
+ if (createNamedContraction &&
+ isa<ContractionOpInterface>(linalgOp.getOperation())) {
+ // Attempt vectorizing directly into a named contraction.
+ // In case of failure, fall back to the generic path.
+ LogicalResult res = vectorizeAsLinalgContraction(
+ rewriter, state, linalgOp, results);
+ if (succeeded(res))
+ return success();
+
+ LDBG("Failed to vectorize as a named contraction.\n");
+ }
+
LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 674a93d7c520e..594bfce57e598 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
- bool useInBoundsInsteadOfMasking) {
+ bool useInBoundsInsteadOfMasking,
+ ArrayRef<bool> scalableDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+ auto vectorType =
+ VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(builder, loc, source);
+ isa<MemRefType>(source.getType())
+ ? memref::getMixedSizes(builder, loc, source)
+ : tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+ auto maskType =
+ VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+ %C: memref<?x?xf32>) {
+ linalg.matmul
+ ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+ affine_map<(m, n, k) -> (n, k)>, // transpose B
+ affine_...
[truncated]
``````````
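For reference, a minimal sketch of driving the extended C++ entry point directly (a hypothetical caller; `rewriter` and `op` are assumed to be supplied by the surrounding pass or rewrite pattern):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Sketch: request direct contraction vectorization. With no explicit vector
// sizes, static shapes are used as-is; createNamedContraction = true asks for
// a direct lowering of compatible named contractions to vector.contract, and
// the vectorizer falls back to the generic path internally on failure.
static LogicalResult vectorizeWithNamedContraction(RewriterBase &rewriter,
                                                   Operation *op) {
  FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
      rewriter, op, /*inputVectorSizes=*/{}, /*inputScalableVecDims=*/{},
      /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
      /*createNamedContraction=*/true);
  return success(succeeded(result));
}
```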
https://github.com/llvm/llvm-project/pull/147296