[Mlir-commits] [mlir] [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect (PR #144307)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Jul 15 05:24:10 PDT 2025
================
@@ -2035,13 +2000,217 @@ struct VectorScalableStepOpLowering
}
};
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+/// %flattened_a = vector.shape_cast %a
+/// %flattened_b = vector.shape_cast %b
+/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+/// %d = vector.shape_cast %flattened_d
+/// %e = add %c, %d
+/// ```
+/// `llvm.intr.matrix.multiply` maps to the `llvm.matrix.multiply` LLVM
+/// intrinsic.
+///
+/// This only kicks in when vectorContractLowering is set to `Matmul` and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+ : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
+public:
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ ContractionOpToMatmulOpLowering(
+ vector::VectorContractLowering vectorContractLowering,
+ MLIRContext *context, PatternBenefit benefit = 100)
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+/// %mta = maybe_transpose
+/// %mtb = maybe_transpose
+/// %flattened_a = vector.shape_cast %mta
+/// %flattened_b = vector.shape_cast %mtb
+/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+/// %mtd = vector.shape_cast %flattened_d
+/// %d = maybe_untranspose %mtd
+/// %e = add %c, %d
+/// ```
+///
+/// This only kicks in when vectorContractLowering is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rew) const {
+ // TODO: Support vector.mask.
+ if (maskOp)
+ return failure();
+
+ auto iteratorTypes = op.getIteratorTypes().getValue();
+ if (!isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]))
+ return failure();
+
+ Type opResType = op.getType();
+ VectorType vecType = dyn_cast<VectorType>(opResType);
+ if (vecType && vecType.isScalable()) {
+ // Note - this is sufficient to reject all cases with scalable vectors.
+ return failure();
+ }
+
+ Type elementType = op.getLhsType().getElementType();
+ if (!elementType.isIntOrFloat())
+ return failure();
+
+ Type dstElementType = vecType ? vecType.getElementType() : opResType;
+ if (elementType != dstElementType)
+ return failure();
+
+ // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+ // Bail out if the contraction cannot be put in this form.
+ MLIRContext *ctx = op.getContext();
+ Location loc = op.getLoc();
+ AffineExpr m, n, k;
+ bindDims(ctx, m, n, k);
+ // LHS must be A(m, k) or A(k, m).
+ Value lhs = op.getLhs();
+ auto lhsMap = op.getIndexingMapsArray()[0];
+ if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+ lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+ else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+ return failure();
+
+ // RHS must be B(k, n) or B(n, k).
+ Value rhs = op.getRhs();
+ auto rhsMap = op.getIndexingMapsArray()[1];
+ if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+ rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+ else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+ return failure();
+
+ // At this point lhs and rhs are in row-major.
+ VectorType lhsType = cast<VectorType>(lhs.getType());
+ VectorType rhsType = cast<VectorType>(rhs.getType());
+ int64_t lhsRows = lhsType.getDimSize(0);
+ int64_t lhsColumns = lhsType.getDimSize(1);
+ int64_t rhsColumns = rhsType.getDimSize(1);
+
+ Type flattenedLHSType =
+ VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+ lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+ Type flattenedRHSType =
+ VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+ rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+ Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+ loc,
+ VectorType::get(lhsRows * rhsColumns,
+ cast<VectorType>(lhs.getType()).getElementType()),
+ lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
+ mul = rew.create<vector::ShapeCastOp>(
+ loc,
+ VectorType::get({lhsRows, rhsColumns},
+ getElementTypeOrSelf(op.getAcc().getType())),
+ mul);
+
+ // ACC must be C(m, n) or C(n, m).
+ auto accMap = op.getIndexingMapsArray()[2];
+ if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+ mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+ else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+ llvm_unreachable("invalid contraction semantics");
+
+ Value res =
+ isa<IntegerType>(elementType)
+ ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+ : static_cast<Value>(
+ rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+ return res;
+}
+
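
For illustration only (not part of the diff; SSA names, shapes, and element type are made up), here is a minimal sketch of a row-major `vector.contract` and the IR this pattern produces for it:

```mlir
// Row-major 2x4 * 4x3 matmul accumulating into %c.
%d = vector.contract {
       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                        affine_map<(m, n, k) -> (k, n)>,
                        affine_map<(m, n, k) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"],
       kind = #vector.kind<add>}
     %a, %b, %c : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>

// After the rewrite: flatten both operands, multiply, reshape, accumulate.
%fa = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
%fb = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
%fd = llvm.intr.matrix.multiply %fa, %fb
        {lhs_rows = 2 : i32, lhs_columns = 4 : i32, rhs_columns = 3 : i32}
      : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
%dd = vector.shape_cast %fd : vector<6xf32> to vector<2x3xf32>
%e = arith.addf %c, %dd : vector<2x3xf32>
```

Both indexing maps are already row-major here, so no `vector.transpose` ops are inserted.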
+/// Progressive lowering of TransposeOp.
+/// A 2-D transpose:
+/// %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+/// %flattened = vector.shape_cast %y
+/// %transposed = llvm.intr.matrix.transpose %flattened
+/// %x = vector.shape_cast %transposed
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ Value input = op.getVector();
+ VectorType inputType = op.getSourceVectorType();
+ VectorType resType = op.getResultVectorType();
+
+ if (inputType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "This lowering does not support scalable vectors");
+
+ // Set up convenience transposition table.
+ ArrayRef<int64_t> transp = op.getPermutation();
+
+ if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0)
+ return failure();
+
+ Type flattenedType =
+ VectorType::get(resType.getNumElements(), resType.getElementType());
+ auto matrix =
+ rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+ auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+ auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+ Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
+ loc, flattenedType, matrix, rows, columns);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+ return success();
+ }
+};
+
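
Again for illustration (not from the diff; names and shapes are made up), the transpose rewrite for a `vector<2x3xf32>` input looks roughly like:

```mlir
%x = vector.transpose %y, [1, 0] : vector<2x3xf32> to vector<3x2xf32>

// becomes (rows/columns are taken from the result shape, as in the code):
%flat = vector.shape_cast %y : vector<2x3xf32> to vector<6xf32>
%t = llvm.intr.matrix.transpose %flat {rows = 3 : i32, columns = 2 : i32}
       : vector<6xf32> into vector<6xf32>
%out = vector.shape_cast %t : vector<6xf32> to vector<3x2xf32>
```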
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
RewritePatternSet &patterns) {
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
}
+/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
+///
+/// Given the high benefit, this pattern will be prioritised over other
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass
+/// performs this registration only conditionally.
+void mlir::vector::populateVectorContractToMatrixMultiply(
+ RewritePatternSet &patterns) {
----------------
adam-smnk wrote:
If they remain separate, it'd be better to expose `benefit` as a parameter. Default can remain high.
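Concretely, the suggestion would amount to something like `populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit = 100)` (a hypothetical signature sketched from the comment above, not code from the PR).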
https://github.com/llvm/llvm-project/pull/144307