[Mlir-commits] [mlir] [mlir][linalg] Vectorize directly to a named contraction (PR #147296)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jul 8 08:04:15 PDT 2025


================
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Write the result vector back to the
+///   destination
+/// The operands shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
----------------
banach-space wrote:

We do have LICM (loop-invariant code motion), and it works pretty well. It's much trickier for MemRef, though.

https://github.com/llvm/llvm-project/pull/147296


More information about the Mlir-commits mailing list