[Mlir-commits] [mlir] [mlir][linalg] Generic to category specialization (PR #184624)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 4 06:32:58 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Adds initial support for the generic-to-category morphism of linalg ops. For now, only conversion of `linalg.generic` to the contraction category op (`linalg.contract`) is supported.
---
Patch is 29.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184624.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+7)
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+4-2)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+23-11)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+8-2)
- (modified) mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp (+6-6)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+51-34)
- (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+147-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 4948bfffad5e0..5998f736ced34 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -889,6 +889,13 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let skipDefaultBuilders = 1;
let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, regionBuilder);
+ }]>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayAttr":$indexingMaps,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index f48ea9849e237..26638b2a644c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -70,8 +70,10 @@ def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
// Specialization path is not guaranteed.
Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
- "convert linalg.generic to equivalent named ops"> ];
- // TODOs: `generic-to-category`, `category-to-named`
+ "convert linalg.generic to equivalent named ops">,
+ Option<"genericToCategory", "generic-to-category", "bool", /*default=*/"false",
+ "convert linalg.generic to equivalent category ops"> ];
+ // TODOs: `category-to-named`
}
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fb9cede670801..1e63455fae096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -923,10 +923,15 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp linalgOp);
-/// Create a namedOp from the given GenericOp and replace the GenericOp.
-/// Currently we can specialize only trivial linalg copy operations.
-FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
- GenericOp genericOp);
+struct SpecializationOptions {
+ // Specialize generics to category ops.
+ bool emitCategoryOps = false;
+};
+
+/// Replace the given GenericOp with a namedOp or categoryOp.
+FailureOr<LinalgOp>
+specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ const SpecializationOptions options = {});
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is either the original subview size when 'useOriginalSubviewSize' is
@@ -1718,17 +1723,24 @@ struct LinalgGeneralizationPattern
};
struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LinalgSpecializationPattern(MLIRContext *context,
+ const SpecializationOptions &options = {},
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOp>(context, benefit), options(options) {}
FailureOr<GenericOp>
returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
- return specializeGenericOp(rewriter, op);
+ return specializeGenericOp(rewriter, op, options);
}
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
+
+private:
+ SpecializationOptions options;
};
/// Vectorization pattern for memref::CopyOp.
@@ -1938,13 +1950,13 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns to convert linalg.generic ops to named
-/// ops where possible. A linalg.generic can represent wide range and complex
-/// computations for which equivalent linalg named op may not exist e.g.
-/// linalg.generic that takes a tensor and computes a polynomial such as:
+/// or category ops where possible. A linalg.generic can represent wide range
+/// and complex computations for which equivalent linalg named op may not exist
+/// e.g. linalg.generic that takes a tensor and computes a polynomial such as:
/// p(x) = an*x^n + ... + a1x + a0
-/// There is no equivalent named op to convert to. Many such cases exist.
+/// There are no equivalent ops to convert to. Many such cases exist.
void populateLinalgGenericOpsSpecializationPatterns(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, const SpecializationOptions &options = {});
/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
/// to equivalent `linalg.elementwise`.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bfc03cc7436df..67d7406987569 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -200,7 +200,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map);
});
- state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ if (none_of(attributes, [](NamedAttribute attr) {
+ return attr.getName() == "indexing_maps";
+ }))
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
}
@@ -217,7 +220,10 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map);
});
- state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ if (none_of(attributes, [](NamedAttribute attr) {
+ return attr.getName() == "indexing_maps";
+ }))
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
index f261ccb1415fe..17416b42c47ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -44,16 +44,16 @@ void LinalgMorphOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
// Lowering paths (named -> category -> generic)
- if (namedToCategory) {
+ if (namedToCategory)
populateLinalgNamedToElementwisePatterns(patterns);
- }
- if (namedToGeneric || categoryToGeneric) {
+ if (namedToGeneric || categoryToGeneric)
populateLinalgNamedOpsGeneralizationPatterns(patterns);
- }
// Lifting paths (named <- category <- generic)
- if (genericToNamed) {
- populateLinalgGenericOpsSpecializationPatterns(patterns);
+ if (genericToNamed || genericToCategory) {
+ SpecializationOptions opts;
+ opts.emitCategoryOps = genericToCategory;
+ populateLinalgGenericOpsSpecializationPatterns(patterns, opts);
}
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d74335e3c08c9..24a02b48427ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -140,21 +140,23 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
std::optional<TypeFn> castTy) {
- SmallVector<NamedAttribute> castAttrVec;
+ SmallVector<NamedAttribute> attributes;
// Only explicitly specify the cast attribute for unsigned cast; signed is
// the default for linalg.matmul/linalg.batch_matmul.
- if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
- castAttrVec = {rewriter.getNamedAttr(
- "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+ if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
+ auto castAttr = rewriter.getNamedAttr(
+ "cast", TypeFnAttr::get(rewriter.getContext(), *castTy));
+ attributes.push_back(castAttr);
+ }
- ArrayAttr indexingMaps = op.getIndexingMaps();
+ // Set the original generic's maps to preserve transposed operand semantics.
+ auto indexingMapsAttr =
+ rewriter.getNamedAttr("indexing_maps", op.getIndexingMapsAttr());
+ attributes.push_back(indexingMapsAttr);
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
- ValueRange{op.getDpsInits()[0]}, castAttrVec);
-
- // Set the original generic's maps to preserve transposed operand semantics.
- namedOp->setAttr("indexing_maps", indexingMaps);
+ ValueRange{op.getDpsInits()[0]}, attributes);
return namedOp;
}
@@ -210,7 +212,8 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
- GenericOp genericOp) {
+ GenericOp genericOp,
+ bool emitCategoryOp) {
if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
return failure();
@@ -220,6 +223,28 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return failure();
+ // Only mul+add contraction is supported.
+ if (!mlir::linalg::detail::isContractionBody(
+ *genericOp.getBlock(), [](Operation *first, Operation *second) {
+ return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+ (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+ (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
+ }))
+ return failure();
+
+ // Determine the cast type for the named matmul op, or bail out if casts
+ // cannot be represented by the named op.
+ std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
+ if (!castTy)
+ return rewriter.notifyMatchFailure(
+ genericOp, "contains invalid cast ops for the named matmul op");
+
+ // In case of category op, wider range of representation is supported.
+ if (emitCategoryOp)
+ return replaceWithMatmulVariant<ContractOp>(rewriter, genericOp, castTy);
+
+ // Further checks for named variants.
+ //
// Linalg generic contraction can be across multiple axis e.g.
// ```
// linalg.generic
@@ -246,14 +271,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();
- if (!mlir::linalg::detail::isContractionBody(
- *genericOp.getBlock(), [](Operation *first, Operation *second) {
- return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
- (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
- (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
- }))
- return failure();
-
// Check rank of operands
auto indexingMaps = genericOp.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
@@ -292,13 +309,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
return failure();
- // Determine the cast type for the named matmul op, or bail out if casts
- // cannot be represented by the named op.
- std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
- if (!castTy)
- return rewriter.notifyMatchFailure(
- genericOp, "contains invalid cast ops for the named matmul op");
-
/// Codegen the different matmul variants.
if (numOfBatchDims) {
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
@@ -406,8 +416,20 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
- GenericOp genericOp) {
+FailureOr<LinalgOp>
+mlir::linalg::specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ const SpecializationOptions options) {
+ // Contraction - e.g. matmul
+ if (isaContractionOpInterface(genericOp)) {
+ return specializeLinalgContractions(rewriter, genericOp,
+ options.emitCategoryOps);
+ }
+
+ // Early exit in case of category specialization.
+ // TODO: Remove when all variants account for both named and category.
+ if (options.emitCategoryOps)
+ return failure();
+
// Copy
if (isaCopyOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
@@ -476,11 +498,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
}
- // Contraction - e.g. matmul
- if (isaContractionOpInterface(genericOp)) {
- return specializeLinalgContractions(rewriter, genericOp);
- }
-
// Convolution - e.g. *conv/pooling*
if (isaConvolutionOpInterface(genericOp)) {
return specializeLinalgConvolutions(rewriter, genericOp);
@@ -509,6 +526,6 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() {
}
void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
- RewritePatternSet &patterns) {
- patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+ RewritePatternSet &patterns, const SpecializationOptions &options) {
+ patterns.add<LinalgSpecializationPattern>(patterns.getContext(), options);
}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87218844c5c39..3f62f27d33dba 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=generic-to-named | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=generic-to-category | FileCheck %s --check-prefix=CATEGORY
#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
@@ -17,6 +18,10 @@ func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tenso
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// Not supported yet.
+// CATEGORY-LABEL: unary_op_exp
+// CATEGORY: linalg.generic
+
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
@@ -36,6 +41,10 @@ func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// Not supported yet.
+// CATEGORY-LABEL: binary_op_div
+// CATEGORY: linalg.generic
+
// -----
///----------------------------------------------------------------------------------------
@@ -62,6 +71,17 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CATEGORY-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CATEGORY-LABEL: op_matmul
+// CATEGORY-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
// Cast-auditing tests: ensure we only specialize when the cast semantics can
// be expressed by linalg.matmul, and use the cast attribute when needed.
@@ -84,6 +104,11 @@ func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi32>,
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CATEGORY-LABEL: op_matmul_unsigned_cast
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: {cast = #linalg.type_fn<cast_unsigned>}
+
// Ensures truncation rounding is tolerated with unsigned cases.
// Note: We only consider casts as conflicting if they have different
// signedness behaviours, and then we do not specialize if they do
@@ -110,6 +135,11 @@ func.func @op_matmul_unsigned_cast_and_truncate(%A: tensor<16x8xi16>, %B: tensor
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CATEGORY-LABEL: op_matmul_unsigned_cast_and_truncate
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: {cast = #linalg.type_fn<cast_unsigned>}
+
// Signed casts are the default, no cast attribute is required.
func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
%Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
@@ -131,6 +161,11 @@ func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
// CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
// CHECK: linalg.matmul
+// CATEGORY-LABEL: op_matmul_signed_cast
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-NOT: {cast =
+
// Mixed signed/unsigned inputs cannot be encoded with a single cast attribute.
func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
%Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
@@ -151,6 +186,10 @@ func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi
// CHECK: linalg.generic
// CHECK-NOT: linalg.matmul
+// CATEGORY-LABEL: negative_op_matmul_mixed_cast
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
+
// Output-side casts are not representable by the named matmul ops.
func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32xi32>,
%Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
@@ -171,6 +210,10 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
// CHECK: linalg.generic
// CHECK-NOT: linalg.matmul
+// CATEGORY-LABEL: negative_op_matmul_output_cast
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
+
// Bitcasts are not modeled by the cast attribute, but should not block
// specialization.
// NOTE: Bitcasts are not preserved by the ma...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184624
More information about the Mlir-commits
mailing list