[Mlir-commits] [mlir] ceda56b - [mlir][linalg] Morphism across linalg -- named, category and generic ops. (#148424)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 7 04:36:51 PDT 2025
Author: Javed Absar
Date: 2025-08-07T12:36:47+01:00
New Revision: ceda56be7f03a790ea777e8b98b419209c3bfa49
URL: https://github.com/llvm/llvm-project/commit/ceda56be7f03a790ea777e8b98b419209c3bfa49
DIFF: https://github.com/llvm/llvm-project/commit/ceda56be7f03a790ea777e8b98b419209c3bfa49.diff
LOG: [mlir][linalg] Morphism across linalg -- named, category and generic ops. (#148424)
Adds `linalg-morph-ops` pass to convert an op from one representation to another:
named-op <--> category_op (elementwise, contraction, ..) <--> generic
e.g.
```mlir
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
```
After `mlir-opt -linalg-morph-ops=named-to-category ..`
```mlir
%0 = linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%arg0 : tensor<16x8xf32> ..
```
Note: this pass is a generalization of the existing converters:
`--linalg-generalize-named-ops` is the path `named-op --> generic-op`
`--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
email: quic_mabsar at quicinc.com
Added:
mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f23662930accc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -89,6 +89,45 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
];
}
+def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
+  let summary = "Convert named op to category ops or generic and vice-versa";
+
+  let description = [{
+    Convert a linalg op from one representation to another equivalent.
+    For example, a linalg named op `linalg.add` can also be written as a
+    category op `linalg.elementwise`, and can also be re-written as
+    a `linalg.generic`, giving the morphism:
+
+      named-op <--> category_op (elementwise, contraction, ..) <--> generic
+
+    Note that the set of `linalg.generic` subsumes named and category ops
+    and therefore not all `linalg.generic` can be converted to named or
+    category op. Similarly, category ops subsume named ops.
+
+    Note:
+     Legacy converters:
+     `--linalg-generalize-named-ops` is the path `named-op --> generic-op`
+     `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+
+  let options = [
+    // named-op <--> category <--> generic
+
+    // Lowering options
+    Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
+           "convert named ops to category op e.g. `linalg.elementwise`">,
+    Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
+           "convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
+    Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
+           "convert named ops e.g. `linalg.add` to `linalg.generic`">,
+
+    // Lifting options
+    // TODOs: `generic-to-category`, `category-to-named`
+    Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
+           "convert linalg.generic to equivalent named ops"> ];
+}
+
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let summary = "Convert named ops into generic ops";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..1e5b5d46de55f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with rewrites that convert elementwise linalg named
+/// ops (e.g. `linalg.add`, `linalg.exp`) to equivalent `linalg.elementwise`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 70f846e5bbd20..6ec2e9fd0be7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
new file mode 100644
index 0000000000000..f261ccb1415fe
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -0,0 +1,62 @@
+//===- MorphOps.cpp - conversion between named,category and generic ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between linalg ops:
+// named <--> category (elementwise, contraction, ..) <--> generic.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMORPHOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-morphism"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pass that rewrites linalg ops between their named, category and generic
+/// representations. Which rewrites are enabled is selected by the boolean
+/// pass options (`namedToCategory`, `categoryToGeneric`, `namedToGeneric`,
+/// `genericToNamed`).
+struct LinalgMorphOpsPass
+    : public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
+
+  using impl::LinalgMorphOpsPassBase<
+      LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
+
+  /// Collects the pattern sets for every enabled direction and applies them
+  /// greedily to the root operation; signals pass failure if the greedy
+  /// driver reports failure.
+  void runOnOperation() override {
+    RewritePatternSet morphPatterns(&getContext());
+
+    // Lowering direction: named -> category -> generic.
+    if (namedToCategory)
+      populateLinalgNamedToElementwisePatterns(morphPatterns);
+    // Both generic-producing options share the generalization patterns.
+    if (categoryToGeneric || namedToGeneric)
+      populateLinalgNamedOpsGeneralizationPatterns(morphPatterns);
+
+    // Lifting direction: named <- generic.
+    if (genericToNamed)
+      populateLinalgGenericOpsSpecializationPatterns(morphPatterns);
+
+    LogicalResult status =
+        applyPatternsGreedily(getOperation(), std::move(morphPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..00a076b6e9746
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,98 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg named ops that are essentially
+// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise` such as folding transpose, broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+/// Returns the `ElementwiseKind` enumerator corresponding to the concrete
+/// linalg named op behind `op`. Ops without a case below are a programming
+/// error and hit `llvm_unreachable`.
+ElementwiseKind getKind(Operation *op) {
+  using Kind = ElementwiseKind;
+  return llvm::TypeSwitch<Operation *, Kind>(op)
+      // Ternary.
+      .Case<SelectOp>([](auto) { return Kind::select; })
+      // Binary.
+      .Case<AddOp>([](auto) { return Kind::add; })
+      .Case<SubOp>([](auto) { return Kind::sub; })
+      .Case<MulOp>([](auto) { return Kind::mul; })
+      .Case<DivOp>([](auto) { return Kind::div; })
+      .Case<DivUnsignedOp>([](auto) { return Kind::div_unsigned; })
+      .Case<PowFOp>([](auto) { return Kind::powf; })
+      // Unary.
+      .Case<ExpOp>([](auto) { return Kind::exp; })
+      .Case<LogOp>([](auto) { return Kind::log; })
+      .Case<AbsOp>([](auto) { return Kind::abs; })
+      .Case<CeilOp>([](auto) { return Kind::ceil; })
+      .Case<FloorOp>([](auto) { return Kind::floor; })
+      .Case<NegFOp>([](auto) { return Kind::negf; })
+      .Case<ReciprocalOp>([](auto) { return Kind::reciprocal; })
+      .Case<RoundOp>([](auto) { return Kind::round; })
+      .Case<SqrtOp>([](auto) { return Kind::sqrt; })
+      .Case<RsqrtOp>([](auto) { return Kind::rsqrt; })
+      .Case<SquareOp>([](auto) { return Kind::square; })
+      .Case<TanhOp>([](auto) { return Kind::tanh; })
+      .Case<ErfOp>([](auto) { return Kind::erf; })
+      .Default([](Operation *) -> Kind {
+        llvm_unreachable("unhandled case in named to elementwise");
+      });
+}
+
+/// Rewrites a linalg named op of type `NamedOpTy` into a `linalg.elementwise`
+/// op carrying the matching `kind` attribute and the named op's indexing
+/// maps, with the same DPS inputs and inits.
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+  using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NamedOpTy op,
+                                PatternRewriter &rewriter) const override {
+    // `kind` selects the elementwise computation; `indexing_maps` is copied
+    // over unchanged from the named op.
+    NamedAttribute kindEntry = rewriter.getNamedAttr(
+        "kind", ElementwiseKindAttr::get(op.getContext(), getKind(op)));
+    NamedAttribute mapsEntry =
+        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps());
+    SmallVector<NamedAttribute> elementwiseAttrs = {kindEntry, mapsEntry};
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, op.getDpsInputs(), op.getDpsInits(), elementwiseAttrs);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers one `NamedToElementwisePattern` per supported elementwise named
+/// op (ternary select, the binary arithmetic ops, and the unary math ops).
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+    RewritePatternSet &patterns) {
+  // A single variadic `add` registers the patterns in the listed order.
+  patterns.add<NamedToElementwisePattern<SelectOp>,
+               NamedToElementwisePattern<AddOp>,
+               NamedToElementwisePattern<SubOp>,
+               NamedToElementwisePattern<MulOp>,
+               NamedToElementwisePattern<DivOp>,
+               NamedToElementwisePattern<DivUnsignedOp>,
+               NamedToElementwisePattern<PowFOp>,
+               NamedToElementwisePattern<ExpOp>,
+               NamedToElementwisePattern<LogOp>,
+               NamedToElementwisePattern<AbsOp>,
+               NamedToElementwisePattern<CeilOp>,
+               NamedToElementwisePattern<FloorOp>,
+               NamedToElementwisePattern<NegFOp>,
+               NamedToElementwisePattern<ReciprocalOp>,
+               NamedToElementwisePattern<RoundOp>,
+               NamedToElementwisePattern<SqrtOp>,
+               NamedToElementwisePattern<RsqrtOp>,
+               NamedToElementwisePattern<SquareOp>,
+               NamedToElementwisePattern<TanhOp>,
+               NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
new file mode 100644
index 0000000000000..2332b287ace8d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s
+
+// Test cases are separated by split markers so that each function is
+// processed in isolation.
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %add : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>)
+//
+func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %sub : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>)
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<4x8x16xf32>
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<select>
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+// CHECK-SAME: outs(%[[E]] : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+//
+func.func @ternary_select(%A: tensor<4x8x16xi1>, %B: tensor<4x8x16xf32>, %C: tensor<4x8x16xf32>)
+    -> tensor<4x8x16xf32> {
+  %empty = tensor.empty() : tensor<4x8x16xf32>
+  %select = linalg.select
+      ins(%A, %B, %C : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+      outs(%empty: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %select : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
new file mode 100644
index 0000000000000..00602c4a36010
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
@@ -0,0 +1,15 @@
+// Forward path `named -> category -> generic`
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s --check-prefix=NAMED_TO_CATEGORY
+
+// The second mlir-opt takes no input file so that it reads the piped output
+// of the first step from stdin; otherwise it would re-read the original
+// named op and never see the category op.
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | \
+// RUN:   mlir-opt -linalg-morph-ops=category-to-generic | FileCheck %s --check-prefix=CATEGORY_TO_GENERIC
+
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp : tensor<16x8xf32>
+}
+// NAMED_TO_CATEGORY: linalg.elementwise
+// NAMED_TO_CATEGORY-NOT: linalg.exp
+
+// CATEGORY_TO_GENERIC: linalg.generic
+// CATEGORY_TO_GENERIC-NOT: linalg.elementwise
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
new file mode 100644
index 0000000000000..ab50a44a37067
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
+
+// Round trip: the second mlir-opt takes no input file so that it consumes the
+// generic op produced by the first step (via stdin) instead of re-reading the
+// original named op.
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt -linalg-morph-ops=generic-to-named | \
+// RUN:   FileCheck %s --check-prefix=ROUND_TRIP
+
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp : tensor<16x8xf32>
+}
+
+// NAMED_TO_GENERIC: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.exp
+
+// ROUND_TRIP: linalg.exp
+// ROUND_TRIP-NOT: linalg.generic
More information about the Mlir-commits
mailing list