[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jan 8 23:46:25 PST 2026
================
@@ -131,17 +132,75 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
}
// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
-// All the variants expressed as pseudo regular expression:
-// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-// have same number of ins/out, so its easy to stamp different versions.
+// All the variants expressed as pseudo regular expression:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec;
+  // Only explicitly specify the cast attribute if the cast type exists and is
+  // pointing to unsigned cast (the default is signed cast for
+  // linalg.matmul/linalg.batch_matmul).
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec);
   return namedOp;
}
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
----------------
banach-space wrote:
Hm, this sounds a bit convoluted.
Shouldn't this return a `std::optional` instead? Then:
* if this is a matmul-like op, the return value is set;
* otherwise, the return value is empty.
Then, also rename it to `getCastTypeForMatmulLikeOp`.
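The interface suggested in the comment above can be sketched roughly as follows. This is a minimal, self-contained C++ sketch, not the code from the PR: `TypeFn` is a two-value stand-in for MLIR's enum, and `OperandCast`, `GenericBodyInfo` and `needsExplicitCastAttr` are hypothetical helpers that summarise the generic op's body instead of walking its region. It also assumes that mixing sign- and zero-extension on the two inputs cannot be expressed with a single `cast` attribute.

```cpp
#include <cassert>
#include <optional>

// Stand-in for the two linalg::TypeFn cast kinds used by the patch; the real
// enum lives in MLIR and has more members.
enum class TypeFn { cast_signed, cast_unsigned };

// Hypothetical classification of the cast (if any) applied to one matmul
// input inside the linalg.generic body.
enum class OperandCast { None, SignExtend, ZeroExtend, Unsupported };

// Hypothetical summary of a linalg.generic body; real code would inspect the
// ops in the region instead.
struct GenericBodyInfo {
  bool isMulAcc;       // body computes `out + lhs * rhs`
  OperandCast lhsCast; // cast applied to the LHS input
  OperandCast rhsCast; // cast applied to the RHS input
};

// Returns the cast kind when the generic op can be expressed as a named
// matmul variant, and std::nullopt when it is not matmul-like.
std::optional<TypeFn> getCastTypeForMatmulLikeOp(const GenericBodyInfo &info) {
  if (!info.isMulAcc)
    return std::nullopt;
  if (info.lhsCast == OperandCast::Unsupported ||
      info.rhsCast == OperandCast::Unsupported)
    return std::nullopt;
  const bool anyZext = info.lhsCast == OperandCast::ZeroExtend ||
                       info.rhsCast == OperandCast::ZeroExtend;
  const bool anySext = info.lhsCast == OperandCast::SignExtend ||
                       info.rhsCast == OperandCast::SignExtend;
  // Assumption: one `cast` attribute cannot express mixed signedness.
  if (anyZext && anySext)
    return std::nullopt;
  return anyZext ? TypeFn::cast_unsigned : TypeFn::cast_signed;
}

// Mirrors the rule in the patch: signed is the default for
// linalg.matmul/linalg.batch_matmul, so only an unsigned cast needs an
// explicit attribute.
bool needsExplicitCastAttr(TypeFn castTy) {
  return castTy == TypeFn::cast_unsigned;
}
```

With this shape, a caller could write `if (auto castTy = getCastTypeForMatmulLikeOp(info))` and forward `*castTy`, so `replaceWithMatmulVariant` might take a plain `TypeFn` rather than `std::optional<TypeFn>`; whether that fits the rest of the pass is an assumption.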
https://github.com/llvm/llvm-project/pull/174757