[Mlir-commits] [mlir] [mlir][linalg] Convert linalg.named to linalg.elementwise op. (PR #148424)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 13 05:08:01 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

<details>
<summary>Changes</summary>

Convert linalg named ops that are elementwise (e.g. `linalg.add`, `linalg.exp`) to `linalg.elementwise`. Currently, named ops have to be lowered to `linalg.generic` (`--linalg-generalize-named-ops`), after which one has to figure out which generics are elementwise. Folding of broadcast or transpose can likewise happen only at the generic level. With this rewrite, such folding can instead happen directly on `linalg.elementwise`.

---
Full diff: https://github.com/llvm/llvm-project/pull/148424.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp (+118) 
- (added) mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir (+38) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f2c1b99b138bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
+  let summary = "Convert linalg named ops to elementwise where possible";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let summary = "Fold transform, broadcast and other ops into elementwise";
   let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..086073c11c80a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that convert linalg named ops, e.g.
+/// `linalg.add`, to the equivalent `linalg.elementwise` op.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like
 /// `linalg.transform` into elementwise op map.
 void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..7cb83377fa0d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  NamedToElementwise.cpp
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..1303b7cb5f6f9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,118 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting of linalg named ops that are essentially
+// elementwise, e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise`, such as folding transpose/broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+      .Case([](SelectOp) { return ElementwiseKind::select; })
+      .Case([](AddOp) { return ElementwiseKind::add; })
+      .Case([](SubOp) { return ElementwiseKind::sub; })
+      .Case([](MulOp) { return ElementwiseKind::mul; })
+      .Case([](DivOp) { return ElementwiseKind::div; })
+      .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+      .Case([](PowFOp) { return ElementwiseKind::powf; })
+      .Case([](ExpOp) { return ElementwiseKind::exp; })
+      .Case([](LogOp) { return ElementwiseKind::log; })
+      .Case([](AbsOp) { return ElementwiseKind::abs; })
+      .Case([](CeilOp) { return ElementwiseKind::ceil; })
+      .Case([](FloorOp) { return ElementwiseKind::floor; })
+      .Case([](NegFOp) { return ElementwiseKind::negf; })
+      .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+      .Case([](RoundOp) { return ElementwiseKind::round; })
+      .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+      .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+      .Case([](SquareOp) { return ElementwiseKind::square; })
+      .Case([](TanhOp) { return ElementwiseKind::tanh; })
+      .Case([](ErfOp) { return ElementwiseKind::erf; })
+      .Default([&](Operation *op) {
+        assert(false && "unexpected op");
+        return ElementwiseKind::sub;
+      });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+  using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NamedOpTy op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<NamedAttribute> attrs;
+    auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+    attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+    attrs.push_back(
+        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+                                               op.getDpsInits(), attrs);
+    return success();
+  }
+};
+
+struct LinalgNamedToElementwisePass
+    : public impl::LinalgNamedToElementwisePassBase<
+          LinalgNamedToElementwisePass> {
+  using impl::LinalgNamedToElementwisePassBase<
+      LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgNamedToElementwisePatterns(patterns);
+
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
new file mode 100644
index 0000000000000..3dc8275117336
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp :  tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %add :  tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
+// CHECK:      linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : memref<16x8xf32>)
+//
+func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
+  linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C :  memref<16x8xf32>)
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/148424


More information about the Mlir-commits mailing list