[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jan 8 23:46:25 PST 2026


================
@@ -131,17 +132,75 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 }
 
 // Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
-//  All the variants expressed as pseudo regular expression:
-//      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-//  have same number of ins/out, so its easy to stamp different versions.
+// All the variants expressed as pseudo regular expression:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
 template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec;
+  // Only explicitly specify the cast attribute if the cast type exists and is
+  // pointing to unsigned cast (the default is signed cast for
+  // linalg.matmul/linalg.batch_matmul).
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec);
   return namedOp;
 }
 
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
----------------
banach-space wrote:

Hm, this sounds a bit convoluted.

Shouldn't the return value be `std::optional`, so that:
* if this is a matmul-like op, the return value is set,
* otherwise, the return value is empty?

Then also rename the function to `getCastTypeForMatmulLikeOp`.
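
A minimal standalone sketch of that shape, not the actual patch: `TypeFn` here is a local stand-in for the Linalg enum, and `MatchedBody` is a hypothetical summary of what the matcher found in the `linalg.generic` body. Two further assumptions: mixed signed/unsigned casts are treated as not expressible by the named op, and a body with no casts yields the signed default mentioned in the patch comment.

```cpp
#include <optional>
#include <vector>

// Stand-in for the Linalg TypeFn enum (assumption: only these two values
// matter for this sketch).
enum class TypeFn { cast_signed, cast_unsigned };

// Hypothetical summary of a matched linalg.generic body.
struct MatchedBody {
  bool isMatmulLike;                // body is a mul + add contraction
  std::vector<TypeFn> operandCasts; // casts seen on the operands, may be empty
};

// Returns the cast kind if the body can be expressed as a named matmul op,
// and std::nullopt otherwise.
std::optional<TypeFn> getCastTypeForMatmulLikeOp(const MatchedBody &body) {
  if (!body.isMatmulLike)
    return std::nullopt;

  std::optional<TypeFn> castTy;
  for (TypeFn cast : body.operandCasts) {
    // Assumption: mixed cast kinds cannot be expressed by a single
    // `cast` attribute, so the op is rejected.
    if (castTy && *castTy != cast)
      return std::nullopt;
    castTy = cast;
  }
  // No casts in the body: fall back to the signed default.
  return castTy.value_or(TypeFn::cast_signed);
}
```

With this shape, the caller would bail out on an empty result and, as in `replaceWithMatmulVariant` above, attach an explicit `cast` attribute only when the result is `TypeFn::cast_unsigned`.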

https://github.com/llvm/llvm-project/pull/174757

