[Mlir-commits] [mlir] [MLIR][Linalg] Remove matmul_transpose variants (PR #147961)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Jul 10 09:59:20 PDT 2025
================
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
       dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{1, 0});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = matmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+    mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-        matmulOp.getOutputs());
+    newLHS = matmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+    mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
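The hunk above replaces the `MatmulTransposeAOp`/`MatmulTransposeBOp` construction with operands plus explicit indexing maps over the loop dims (d0, d1, d2), where d0 and d1 are the output dims and d2 is the reduction dim. The transposed operand is indexed as {d2, d0} on the LHS or {d1, d2} on the RHS. As a sanity check on those maps, here is a small standalone C++ sketch. It is not MLIR code: maps are modeled as pairs of dim indices, and `DimMap`, `Matrix`, `contract` and `transpose` are hypothetical helpers. It interprets a matmul-like contraction from its maps, so one can check that both transposed variants compute the same result as the plain {d0, d2}, {d2, d1}, {d0, d1} matmul.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Permutation-only 2-D indexing map: each entry names a loop dim
// (0 = d0 = m, 1 = d1 = n, 2 = d2 = k), mirroring e.g. {d2, d0}.
using DimMap = std::array<int, 2>;

// Minimal row-major matrix; at() asserts in-bounds access so a wrong map
// (e.g. indexing A^T as if it were A) is caught in debug builds.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double &at(std::size_t i, std::size_t j) {
    assert(i < rows && j < cols);
    return data[i * cols + j];
  }
  double at(std::size_t i, std::size_t j) const {
    assert(i < rows && j < cols);
    return data[i * cols + j];
  }
};

// Naive interpreter: for every point d = (d0, d1, d2) of the iteration space,
// out[mapO(d)] += lhs[mapL(d)] * rhs[mapR(d)].
void contract(const Matrix &lhs, const DimMap &mapL, const Matrix &rhs,
              const DimMap &mapR, Matrix &out, const DimMap &mapO,
              const std::array<std::size_t, 3> &bounds) {
  std::array<std::size_t, 3> d{};
  for (d[0] = 0; d[0] < bounds[0]; ++d[0])
    for (d[1] = 0; d[1] < bounds[1]; ++d[1])
      for (d[2] = 0; d[2] < bounds[2]; ++d[2])
        out.at(d[mapO[0]], d[mapO[1]]) +=
            lhs.at(d[mapL[0]], d[mapL[1]]) * rhs.at(d[mapR[0]], d[mapR[1]]);
}

// Materialized transpose, playing the role of the linalg.transpose in the diff.
Matrix transpose(const Matrix &a) {
  Matrix t(a.cols, a.rows);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < a.cols; ++j)
      t.at(j, i) = a.at(i, j);
  return t;
}
```

For example, with A of shape 2x3 and B of shape 3x2 (bounds {2, 2, 3} for m, n, k), `contract(transpose(A), {2, 0}, B, {2, 1}, C, {0, 1}, {2, 2, 3})` accumulates into a zero-initialized C the same values as the plain maps applied to A and B.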
----------------
qedawkins wrote:
Pattern matching can't be replaced as easily, but we can add bespoke C++ helpers for it, like `matchMatmulTransposeB`.
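Presumably, code that previously matched `linalg::MatmulTransposeBOp` by its op type would instead have to inspect the indexing maps of a matmul. A minimal sketch of such a check, assuming permutation-only maps represented as plain dim-index pairs rather than MLIR `AffineMap`s. `MatmulMaps`, `MatmulKind`, `classifyMatmul` and this `matchMatmulTransposeB` are hypothetical names, not an existing MLIR API; in MLIR the maps would come from the op itself and be compared as `AffineMap`s, which is not shown here.

```cpp
#include <array>
#include <cassert>
#include <initializer_list>

// Permutation-only 2-D indexing map: each entry names a loop dim
// (0 = d0 = m, 1 = d1 = n, 2 = d2 = k).
using DimMap = std::array<int, 2>;

// The three indexing maps of a matmul-like op, in (lhs, rhs, out) order.
struct MatmulMaps {
  DimMap lhs, rhs, out;
};

enum class MatmulKind { Plain, TransposeA, TransposeB, Other };

// Hypothetical classifier: recovers which former named op (plain matmul,
// transpose-A or transpose-B variant) a set of maps corresponds to.
MatmulKind classifyMatmul(const MatmulMaps &m) {
  for (const DimMap *map : {&m.lhs, &m.rhs, &m.out})
    for (int d : *map)
      assert(d >= 0 && d < 3 && "maps may only reference d0, d1, d2");
  if (m.out != DimMap{0, 1})
    return MatmulKind::Other;
  if (m.lhs == DimMap{0, 2} && m.rhs == DimMap{2, 1})
    return MatmulKind::Plain;
  if (m.lhs == DimMap{2, 0} && m.rhs == DimMap{2, 1})
    return MatmulKind::TransposeA;
  if (m.lhs == DimMap{0, 2} && m.rhs == DimMap{1, 2})
    return MatmulKind::TransposeB;
  return MatmulKind::Other;
}

// Hypothetical stand-in for the `matchMatmulTransposeB` helper suggested above.
bool matchMatmulTransposeB(const MatmulMaps &m) {
  return classifyMatmul(m) == MatmulKind::TransposeB;
}
```

For instance, `matchMatmulTransposeB({{0, 2}, {1, 2}, {0, 1}})` returns true for the maps the `else` branch of the diff builds. Classifying all variants in one function keeps the map constants in one place if more matchers are needed.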
https://github.com/llvm/llvm-project/pull/147961