[Mlir-commits] [mlir] [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect (PR #144307)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jul 15 05:24:10 PDT 2025


================
@@ -2035,13 +2000,217 @@ struct VectorScalableStepOpLowering
   }
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `llvm.intr.matrix.multiply` maps to the `llvm.matrix.multiply` intrinsic.
+//
+/// This only kicks in when vectorContractLowering is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
+public:
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorContractLowering vectorContractLowering,
+      MLIRContext *context, PatternBenefit benefit = 100)
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %mta = maybe_transpose
+///    %mtb = maybe_transpose
+///    %flattened_a = vector.shape_cast %mta
+///    %flattened_b = vector.shape_cast %mtb
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+///    %mtd = vector.shape_cast %flattened_d
+///    %d = maybe_untranspose %mtd
+///    %e = add %c, %d
+/// ```
+//
+/// This only kicks in when vectorContractLowering is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
+  // TODO: Support vector.mask.
+  if (maskOp)
+    return failure();
+
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+
+  Type opResType = op.getType();
+  VectorType vecType = dyn_cast<VectorType>(opResType);
+  if (vecType && vecType.isScalable()) {
+    // Note - this is sufficient to reject all cases with scalable vectors.
+    return failure();
+  }
+
+  Type elementType = op.getLhsType().getElementType();
+  if (!elementType.isIntOrFloat())
+    return failure();
+
+  Type dstElementType = vecType ? vecType.getElementType() : opResType;
+  if (elementType != dstElementType)
+    return failure();
+
+  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+  // Bail out if the contraction cannot be put in this form.
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  AffineExpr m, n, k;
+  bindDims(rew.getContext(), m, n, k);
+  // LHS must be A(m, k) or A(k, m).
+  Value lhs = op.getLhs();
+  auto lhsMap = op.getIndexingMapsArray()[0];
+  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+    return failure();
+
+  // RHS must be B(k, n) or B(n, k).
+  Value rhs = op.getRhs();
+  auto rhsMap = op.getIndexingMapsArray()[1];
+  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+    return failure();
+
+  // At this point lhs and rhs are in row-major.
+  VectorType lhsType = cast<VectorType>(lhs.getType());
+  VectorType rhsType = cast<VectorType>(rhs.getType());
+  int64_t lhsRows = lhsType.getDimSize(0);
+  int64_t lhsColumns = lhsType.getDimSize(1);
+  int64_t rhsColumns = rhsType.getDimSize(1);
+
+  Type flattenedLHSType =
+      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+  Type flattenedRHSType =
+      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+      loc,
+      VectorType::get(lhsRows * rhsColumns,
+                      cast<VectorType>(lhs.getType()).getElementType()),
+      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
+  mul = rew.create<vector::ShapeCastOp>(
+      loc,
+      VectorType::get({lhsRows, rhsColumns},
+                      getElementTypeOrSelf(op.getAcc().getType())),
+      mul);
+
+  // ACC must be C(m, n) or C(n, m).
+  auto accMap = op.getIndexingMapsArray()[2];
+  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+    llvm_unreachable("invalid contraction semantics");
+
+  Value res =
+      isa<IntegerType>(elementType)
+          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+          : static_cast<Value>(
+                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+  return res;
+}
+
+/// Lowering of a 2-D TransposeOp to `llvm.intr.matrix.transpose`.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+///   %0 = vector.shape_cast %y
+///   %1 = llvm.intr.matrix.transpose %0
+///   %x = vector.shape_cast %1
+/// Scalable vectors, and any rank or permutation other than 2-D [1, 0],
+/// are not matched.
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
+
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
+    // Set up convenience transposition table.
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
+      return failure();
+    }
+
+    Type flattenedType =
+        VectorType::get(resType.getNumElements(), resType.getElementType());
+    auto matrix =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+    auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+    auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+    Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
+        loc, flattenedType, matrix, rows, columns);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
     RewritePatternSet &patterns) {
   patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
 }
 
+/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
+///
+/// Given the high benefit, this will be prioritised over other
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void mlir::vector::populateVectorContractToMatrixMultiply(
+    RewritePatternSet &patterns) {
----------------
adam-smnk wrote:

If they remain separate, it'd be better to expose `benefit` as a parameter. Default can remain high.

https://github.com/llvm/llvm-project/pull/144307


More information about the Mlir-commits mailing list