[Mlir-commits] [mlir] [mlir][linalg] Vectorize directly to a named contraction (PR #147296)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jul 8 08:04:15 PDT 2025
================
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination
+/// The operand shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the
+/// contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
----------------
banach-space wrote:
We do have LICM and it works pretty well. It's much trickier for MemRef though.
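
For anyone following along, here is a rough sketch of the IR this new path is
expected to produce, assuming a static linalg.matmul on tensors (the shapes,
SSA names and exact attribute spelling below are illustrative only, not taken
from the patch or its tests):

  // Input (hypothetical): a named contraction with static shapes.
  %0 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
                     outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>

  // After vectorization: operand shapes are read as-is, the permutation and
  // the reduction stay inside vector.contract, and the result is written
  // back to the destination tensor.
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %lhs = vector.transfer_read %A[%c0, %c0], %cst
           : tensor<4x8xf32>, vector<4x8xf32>
  %rhs = vector.transfer_read %B[%c0, %c0], %cst
           : tensor<8x16xf32>, vector<8x16xf32>
  %acc = vector.transfer_read %C[%c0, %c0], %cst
           : tensor<4x16xf32>, vector<4x16xf32>
  %res = vector.contract {
           indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                            affine_map<(d0, d1, d2) -> (d2, d1)>,
                            affine_map<(d0, d1, d2) -> (d0, d1)>],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>}
         %lhs, %rhs, %acc
           : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
  %1 = vector.transfer_write %res, %C[%c0, %c0]
         : vector<4x16xf32>, tensor<4x16xf32>

With tensor semantics the reads and writes above are plain SSA values, which
is what makes hoisting them out of loops straightforward; on memrefs the same
move needs aliasing to be ruled out first.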
https://github.com/llvm/llvm-project/pull/147296
More information about the Mlir-commits mailing list