[Mlir-commits] [mlir] [mlir][linalg] Convert linalg.named to linalg.elementwise op. (PR #148424)
Javed Absar
llvmlistbot at llvm.org
Sun Jul 13 05:07:23 PDT 2025
https://github.com/javedabsar1 created https://github.com/llvm/llvm-project/pull/148424
Convert linalg named ops that are elementwise (e.g. add/exp) to `linalg.elementwise`. Currently, named ops have to be lowered to linalg.generic (--linalg-generalize-named-ops), after which one has to figure out which generics are elementwise. Also, folding of broadcasts or transposes can then only happen at the generic level. With this rewrite, these can now happen on linalg.elementwise instead.
>From 56c84313ce6547c172964e8af27afbeba6aeb0a3 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 12 Jul 2025 09:59:56 -0400
Subject: [PATCH] [mlir][linalg] Convert linalg.named to linalg.elementwise op.
Convert linalg named ops that are elementwise (e.g. add/exp)
to `linalg.elementwise`. Currently, named ops have to be lowered
to linalg.generic (--linalg-generalize-named-ops), after which
one has to figure out which generics are elementwise. Also,
folding of broadcasts or transposes can then only happen at the
generic level. With this rewrite, these can now happen on
linalg.elementwise instead.
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 5 +
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/NamedToElementwise.cpp | 118 ++++++++++++++++++
.../elementwise/named_to_elementwise.mlir | 38 ++++++
5 files changed, 166 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
create mode 100644 mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f2c1b99b138bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
+  let summary = "Convert linalg named ops to elementwise where possible";
+  let description = [{
+    Rewrites linalg named ops whose computation is purely elementwise
+    (e.g. `linalg.add`, `linalg.exp`) into an equivalent `linalg.elementwise`
+    op carrying the matching `kind` attribute and the named op's indexing
+    maps. This lets later rewrites, such as folding broadcasts or transposes
+    (see `linalg-fold-into-elementwise`), work on `linalg.elementwise`
+    directly instead of requiring generalization to `linalg.generic`.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transform, broadcast and other ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..086073c11c80a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with rewrites that convert linalg named ops which are
+/// purely elementwise (e.g. `linalg.add`, `linalg.exp`) into equivalent
+/// `linalg.elementwise` ops carrying the corresponding elementwise `kind`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..7cb83377fa0d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..1303b7cb5f6f9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,118 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the rewriting of linalg named ops that are essentially
+// elementwise (e.g. `linalg.exp`) into `linalg.elementwise`. This enables
+// further optimizations on `linalg.elementwise`, such as folding transposes
+// and broadcasts.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+/// Returns the `linalg.elementwise` kind that computes the same function as
+/// the given elementwise named op.
+///
+/// Only the named ops listed in the cases below are supported; calling this
+/// on any other op is a programming error.
+ElementwiseKind getKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+      .Case([](SelectOp) { return ElementwiseKind::select; })
+      .Case([](AddOp) { return ElementwiseKind::add; })
+      .Case([](SubOp) { return ElementwiseKind::sub; })
+      .Case([](MulOp) { return ElementwiseKind::mul; })
+      .Case([](DivOp) { return ElementwiseKind::div; })
+      .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+      .Case([](PowFOp) { return ElementwiseKind::powf; })
+      .Case([](ExpOp) { return ElementwiseKind::exp; })
+      .Case([](LogOp) { return ElementwiseKind::log; })
+      .Case([](AbsOp) { return ElementwiseKind::abs; })
+      .Case([](CeilOp) { return ElementwiseKind::ceil; })
+      .Case([](FloorOp) { return ElementwiseKind::floor; })
+      .Case([](NegFOp) { return ElementwiseKind::negf; })
+      .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+      .Case([](RoundOp) { return ElementwiseKind::round; })
+      .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+      .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+      .Case([](SquareOp) { return ElementwiseKind::square; })
+      .Case([](TanhOp) { return ElementwiseKind::tanh; })
+      .Case([](ErfOp) { return ElementwiseKind::erf; })
+      .Default([](Operation *) -> ElementwiseKind {
+        // Reaching this point means a rewrite was registered for an op that
+        // has no kind mapping. Fail loudly rather than silently returning an
+        // arbitrary (and therefore wrong) kind.
+        llvm_unreachable("unhandled op in linalg named-to-elementwise");
+      });
+}
+
+/// Rewrites one named op of type `NamedOpTy` into a `linalg.elementwise` op.
+/// The new op reuses the named op's inputs, inits and indexing maps, and is
+/// tagged with the `ElementwiseKind` reported by `getKind`.
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+  using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NamedOpTy op,
+                                PatternRewriter &rewriter) const override {
+    // The elementwise op is fully described by its kind plus the named op's
+    // indexing maps, which are forwarded unchanged.
+    ElementwiseKind kind = getKind(op.getOperation());
+    NamedAttribute kindEntry = rewriter.getNamedAttr(
+        "kind", ElementwiseKindAttr::get(rewriter.getContext(), kind));
+    NamedAttribute mapsEntry =
+        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps());
+    SmallVector<NamedAttribute, 2> attrs = {kindEntry, mapsEntry};
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+                                               op.getDpsInits(), attrs);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the named-to-elementwise rewrite patterns to the
+/// operation it is scheduled on.
+struct LinalgNamedToElementwisePass
+    : public impl::LinalgNamedToElementwisePassBase<
+          LinalgNamedToElementwisePass> {
+  using impl::LinalgNamedToElementwisePassBase<
+      LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    populateLinalgNamedToElementwisePatterns(patterns);
+
+    // If the greedy driver reports failure, mark the whole pass as failed.
+    Operation *root = getOperation();
+    if (failed(applyPatternsGreedily(root, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+    RewritePatternSet &patterns) {
+  // One pattern per supported named op. Every op mapped by `getKind` must be
+  // listed here; otherwise it is silently left unconverted.
+  patterns.add<NamedToElementwisePattern<SelectOp>,
+               NamedToElementwisePattern<AddOp>,
+               NamedToElementwisePattern<SubOp>,
+               NamedToElementwisePattern<MulOp>,
+               NamedToElementwisePattern<DivOp>,
+               NamedToElementwisePattern<DivUnsignedOp>,
+               NamedToElementwisePattern<PowFOp>,
+               NamedToElementwisePattern<ExpOp>,
+               NamedToElementwisePattern<LogOp>,
+               NamedToElementwisePattern<AbsOp>,
+               NamedToElementwisePattern<CeilOp>,
+               NamedToElementwisePattern<FloorOp>,
+               NamedToElementwisePattern<NegFOp>,
+               NamedToElementwisePattern<ReciprocalOp>,
+               NamedToElementwisePattern<RoundOp>,
+               NamedToElementwisePattern<SqrtOp>,
+               NamedToElementwisePattern<RsqrtOp>,
+               NamedToElementwisePattern<SquareOp>,
+               NamedToElementwisePattern<TanhOp>,
+               NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
new file mode 100644
index 0000000000000..3dc8275117336
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @exp
+// CHECK-SAME: (%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32>
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @add
+// CHECK-SAME: (%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32>
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %add : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sub
+// CHECK-SAME: (%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
+// CHECK: linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : memref<16x8xf32>)
+//
+func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
+  linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C : memref<16x8xf32>)
+  return
+}
More information about the Mlir-commits
mailing list