[Mlir-commits] [mlir] ceda56b - [mlir][linalg] Morphism across linalg -- named, category and generic ops. (#148424)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 7 04:36:51 PDT 2025


Author: Javed Absar
Date: 2025-08-07T12:36:47+01:00
New Revision: ceda56be7f03a790ea777e8b98b419209c3bfa49

URL: https://github.com/llvm/llvm-project/commit/ceda56be7f03a790ea777e8b98b419209c3bfa49
DIFF: https://github.com/llvm/llvm-project/commit/ceda56be7f03a790ea777e8b98b419209c3bfa49.diff

LOG: [mlir][linalg] Morphism across linalg -- named, category and generic ops. (#148424)

Adds a `linalg-morph-ops` pass that converts an op from one representation to another:
   named-op <--> category_op (elementwise, contraction, ..) <--> generic
e.g.
```mlir
  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
```
After `mlir-opt -linalg-morph-ops=named-to-category ..`

```mlir
  %0 = linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%arg0 : tensor<16x8xf32> ..
```

Note: this is a generalization of the existing passes:
`--linalg-generalize-named-ops` is the path `named-op --> generic-op`
`--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`

email: quic_mabsar at quicinc.com

Added: 
    mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
    mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
    mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
    mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f23662930accc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -89,6 +89,45 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
   ];
 }
 
+def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
+  let summary = "Convert named ops to category or generic ops and vice versa";
+
+  let description = [{
+    Convert a linalg op from one representation to an equivalent one.
+    For example, a linalg named op `linalg.add` can also be written as a
+    category op `linalg.elementwise`, and can also be re-written as
+    a `linalg.generic`, giving the morphism:
+
+      named-op <--> category_op (elementwise, contraction, ..) <--> generic
+
+    Note that `linalg.generic` subsumes named and category ops, and
+    therefore not every `linalg.generic` can be converted to a named or
+    category op. Similarly, category ops subsume named ops.
+
+    Note:
+     Legacy converters:
+     `--linalg-generalize-named-ops` is the path `named-op --> generic-op`
+     `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+
+  let options = [
+    // named-op <--> category <--> generic
+
+    // Lowering options
+    Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
+           "convert named ops to category op e.g. `linalg.elementwise`">,
+    Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
+           "convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
+    Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
+           "convert named ops e.g. `linalg.add` to `linalg.generic`">,
+
+    // Lifting options
+    //  TODOs: `generic-to-category`, `category-to-named`
+    Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
+           "convert linalg.generic to equivalent named ops"> ];
+}
+
 def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let summary = "Convert named ops into generic ops";
   let dependentDialects = ["linalg::LinalgDialect"];

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..1e5b5d46de55f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that rewrite elementwise named ops
+/// such as `linalg.add` or `linalg.exp` into a `linalg.elementwise` op.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like
 /// `linalg.transform` into elementwise op map.
 void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 70f846e5bbd20..6ec2e9fd0be7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MorphOps.cpp
   TransposeMatmul.cpp
   ShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  NamedToElementwise.cpp
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
new file mode 100644
index 0000000000000..f261ccb1415fe
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -0,0 +1,75 @@
+//===- MorphOps.cpp - conversion between named, category and generic ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between linalg ops:
+//    named <--> category (elementwise, contraction, ..) <--> generic.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMORPHOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-morphism"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pass that rewrites linalg ops between their named, category (e.g.
+/// `linalg.elementwise`) and `linalg.generic` forms, as selected by the
+/// pass options.
+struct LinalgMorphOpsPass
+    : public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
+
+  using impl::LinalgMorphOpsPassBase<
+      LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
+
+  void runOnOperation() override;
+};
+
+void LinalgMorphOpsPass::runOnOperation() {
+  // Lowering to `linalg.generic` and lifting from it undo each other, so
+  // running both in one greedy rewrite can bounce ops back and forth without
+  // converging. Reject that combination with a clear diagnostic up front.
+  const bool lowerToGeneric = namedToGeneric || categoryToGeneric;
+  if (lowerToGeneric && genericToNamed) {
+    getOperation()->emitError()
+        << "linalg-morph-ops: `generic-to-named` cannot be combined with "
+           "`named-to-generic` or `category-to-generic`";
+    return signalPassFailure();
+  }
+
+  RewritePatternSet patterns(&getContext());
+
+  // Lowering paths (named -> category -> generic).
+  if (namedToCategory)
+    populateLinalgNamedToElementwisePatterns(patterns);
+  // NOTE(review): both generic-lowering options share one pattern set, so
+  // `category-to-generic` presumably also generalizes named ops (and
+  // `named-to-generic` category ops) -- confirm if that matters.
+  if (lowerToGeneric)
+    populateLinalgNamedOpsGeneralizationPatterns(patterns);
+
+  // Lifting paths (named <- category <- generic).
+  if (genericToNamed)
+    populateLinalgGenericOpsSpecializationPatterns(patterns);
+
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+} // namespace

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..00a076b6e9746
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,106 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rewrites the linalg named ops that are elementwise in nature (for example
+// `linalg.exp`) into the category op `linalg.elementwise`. Transforms that
+// target `linalg.elementwise`, such as folding in transposes and broadcasts,
+// can then be applied to them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+/// Returns the `ElementwiseKind` that corresponds to the named op `op`.
+/// Every op type handled by a `NamedToElementwisePattern` must have a case
+/// here; reaching the default is a programming error.
+ElementwiseKind getKind(Operation *op) {
+  using KindSwitch = llvm::TypeSwitch<Operation *, ElementwiseKind>;
+  return KindSwitch(op)
+      .Case([](SelectOp) { return ElementwiseKind::select; })
+      .Case([](AddOp) { return ElementwiseKind::add; })
+      .Case([](SubOp) { return ElementwiseKind::sub; })
+      .Case([](MulOp) { return ElementwiseKind::mul; })
+      .Case([](DivOp) { return ElementwiseKind::div; })
+      .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+      .Case([](PowFOp) { return ElementwiseKind::powf; })
+      .Case([](ExpOp) { return ElementwiseKind::exp; })
+      .Case([](LogOp) { return ElementwiseKind::log; })
+      .Case([](AbsOp) { return ElementwiseKind::abs; })
+      .Case([](CeilOp) { return ElementwiseKind::ceil; })
+      .Case([](FloorOp) { return ElementwiseKind::floor; })
+      .Case([](NegFOp) { return ElementwiseKind::negf; })
+      .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+      .Case([](RoundOp) { return ElementwiseKind::round; })
+      .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+      .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+      .Case([](SquareOp) { return ElementwiseKind::square; })
+      .Case([](TanhOp) { return ElementwiseKind::tanh; })
+      .Case([](ErfOp) { return ElementwiseKind::erf; })
+      .Default([](Operation *) -> ElementwiseKind {
+        llvm_unreachable("unhandled case in named to elementwise");
+      });
+}
+
+/// Replaces a `NamedOpTy` op with a `linalg.elementwise` that carries the
+/// same inputs, inits and indexing maps, tagged with the op's kind.
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+  using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NamedOpTy op,
+                                PatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+    NamedAttribute kindEntry = rewriter.getNamedAttr(
+        "kind", ElementwiseKindAttr::get(ctx, getKind(op)));
+    NamedAttribute mapsEntry =
+        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps());
+    SmallVector<NamedAttribute> attrs = {kindEntry, mapsEntry};
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+                                               op.getDpsInits(), attrs);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+    RewritePatternSet &patterns) {
+  // Register one pattern instance per supported named op in a single call.
+  patterns.add<NamedToElementwisePattern<SelectOp>,
+               NamedToElementwisePattern<AddOp>,
+               NamedToElementwisePattern<SubOp>,
+               NamedToElementwisePattern<MulOp>,
+               NamedToElementwisePattern<DivOp>,
+               NamedToElementwisePattern<DivUnsignedOp>,
+               NamedToElementwisePattern<PowFOp>,
+               NamedToElementwisePattern<ExpOp>,
+               NamedToElementwisePattern<LogOp>,
+               NamedToElementwisePattern<AbsOp>,
+               NamedToElementwisePattern<CeilOp>,
+               NamedToElementwisePattern<FloorOp>,
+               NamedToElementwisePattern<NegFOp>,
+               NamedToElementwisePattern<ReciprocalOp>,
+               NamedToElementwisePattern<RoundOp>,
+               NamedToElementwisePattern<SqrtOp>,
+               NamedToElementwisePattern<RsqrtOp>,
+               NamedToElementwisePattern<SquareOp>,
+               NamedToElementwisePattern<TanhOp>,
+               NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
new file mode 100644
index 0000000000000..2332b287ace8d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp :  tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %add :  tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>)
+//
+func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %sub : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>)
+// CHECK:   %[[E:.+]] =  tensor.empty() : tensor<4x8x16xf32>
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<select>
+// CHECK-SAME:       ins(%[[A]], %[[B]], %[[C]] : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+// CHECK-SAME:       outs(%[[E]] : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+//
+func.func @ternary_select(%A: tensor<4x8x16xi1>, %B: tensor<4x8x16xf32>, %C: tensor<4x8x16xf32>)
+             -> tensor<4x8x16xf32> {
+  %empty = tensor.empty() : tensor<4x8x16xf32>
+  %select = linalg.select
+              ins(%A, %B, %C : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+              outs(%empty: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %select : tensor<4x8x16xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
new file mode 100644
index 0000000000000..00602c4a36010
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
@@ -0,0 +1,15 @@
+// Forward path `named -> category -> generic`
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s  --check-prefix=NAMED_TO_CATEGORY
+// The second mlir-opt below reads the IR produced by the first one from stdin.
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category |  \
+// RUN:   mlir-opt -linalg-morph-ops=category-to-generic | FileCheck %s  --check-prefix=CATEGORY_TO_GENERIC
+
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp :  tensor<16x8xf32>
+}
+// NAMED_TO_CATEGORY: linalg.elementwise
+// NAMED_TO_CATEGORY-NOT: linalg.exp
+
+// CATEGORY_TO_GENERIC: linalg.generic
+// CATEGORY_TO_GENERIC-NOT: linalg.elementwise

diff  --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
new file mode 100644
index 0000000000000..ab50a44a37067
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  FileCheck %s  --check-prefix=NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  mlir-opt -linalg-morph-ops=generic-to-named | \
+// RUN:   FileCheck %s  --check-prefix=ROUND_TRIP
+// The second mlir-opt above reads the generalized IR from stdin, so the round trip is really tested.
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp :  tensor<16x8xf32>
+}
+
+// NAMED_TO_GENERIC: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.exp
+
+// ROUND_TRIP: linalg.exp
+// ROUND_TRIP-NOT: linalg.generic


        


More information about the Mlir-commits mailing list