[Mlir-commits] [mlir] [mlir][linalg][elementwise] Fold transpose into new elementwise (PR #130207)

Javed Absar llvmlistbot at llvm.org
Sat Mar 8 13:44:38 PST 2025


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/130207

>From c8ec0f4c8aaedf7a74da3b6102500c895fe5fa9b Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Thu, 6 Mar 2025 19:18:37 -0500
Subject: [PATCH 1/2] [mlir][linalg][elementwise] Fold transpose into new
 elementwise

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 14 +++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 42 ++++++++++++++++++
 .../test/Dialect/Linalg/elementwise/fold.mlir | 43 +++++++++++++++++++
 3 files changed, 98 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/elementwise/fold.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f7b1d2c9dfcb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
           attributes, ElementwiseOp::getRegionBuilder());
-      }]>
+      }]>,
+
+     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+          "ElementwiseKindAttr":$kind,
+          "ArrayAttr":$indexingMaps,
+          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("kind", kind);
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, ElementwiseOp::getRegionBuilder());
+       }]>
     ];
 
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f6b7c32659bb5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+namespace {
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+      if (!map.isIdentity() || !transposeOp) {
+        // push in original operand and its map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(transposeOp.getInput());
+      // push in transposeOp's inverse permutation map.
+      newMaps.push_back(transposeOp.getMatchingIndexingMap(
+          transposeOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    newMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldTranspose>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..7b2ff0b6de12e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK:  func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %empty = tensor.empty() : tensor<8x16x32xf32>
+  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty :  tensor<8x16x32xf32>) permutation = [1, 0, 2] // Expected to fold into the elementwise indexing map.
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK:  func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> // Same shape as %A.
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}

>From e33dd8cbaa8109131e07c138b548c370fdbdee07 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 7 Mar 2025 14:03:36 -0500
Subject: [PATCH 2/2] Address review comments: move transpose folding from
 canonicalization into a new -linalg-fold-into-elementwise pass.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 -
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  5 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 42 ---------
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  1 +
 .../Linalg/Transforms/FoldIntoElementwise.cpp | 89 +++++++++++++++++++
 .../test/Dialect/Linalg/elementwise/fold.mlir |  2 +-
 7 files changed, 100 insertions(+), 44 deletions(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f7b1d2c9dfcb3..308e39a9a51e1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -618,7 +618,6 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..373842c9b03de 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
+  let summary = "Fold transpose ops into linalg.elementwise indexing maps";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8fdcdeff250bb..c302f6d682d69 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold operations like
+/// `linalg.transpose` into the indexing maps of `linalg.elementwise`.
+void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f6b7c32659bb5..07b19e5cb1a89 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4286,47 +4285,6 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
-namespace {
-struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
-  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ElementwiseOp op,
-                                PatternRewriter &rewriter) const override {
-    bool changed = false;
-    SmallVector<Value> newIns;
-    SmallVector<AffineMap> newMaps;
-    for (OpOperand *operand : op.getDpsInputOperands()) {
-      AffineMap map = op.getMatchingIndexingMap(operand);
-      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
-
-      if (!map.isIdentity() || !transposeOp) {
-        // push in original operand and its map.
-        newIns.push_back(operand->get());
-        newMaps.push_back(map);
-        continue;
-      }
-      newIns.push_back(transposeOp.getInput());
-      // push in transposeOp's inverse permutation map.
-      newMaps.push_back(transposeOp.getMatchingIndexingMap(
-          transposeOp.getDpsInputOperand(0)));
-      changed = true;
-    }
-    if (!changed)
-      return failure();
-    newMaps.push_back(op.getIndexingMapsArray().back());
-
-    rewriter.replaceOpWithNewOp<ElementwiseOp>(
-        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
-        rewriter.getAffineMapArrayAttr(newMaps));
-    return success();
-  }
-};
-} // namespace
-void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                MLIRContext *context) {
-  results.add<FoldTranspose>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d18b6f8afc43b..881d9fcb4f52e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FoldIntoElementwise.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
new file mode 100644
index 0000000000000..bdd4f6025b051
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -0,0 +1,89 @@
+//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements folding ops such as transpose and broadcast into the
+// affine maps of the elementwise op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-fold-into-elementwise"
+
+namespace {
+/// Folds a linalg.transpose feeding an identity-mapped input into its map.
+struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+      if (!map.isIdentity() || !transposeOp) {
+        // Not an identity-mapped transpose result: keep operand and map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(transposeOp.getInput());
+      // Read the transpose input through its inverse-permutation map.
+      newMaps.push_back(transposeOp.getMatchingIndexingMap(
+          transposeOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    newMaps.push_back(op.getIndexingMapsArray().back()); // Keep init map.
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+
+/// Applies the fold-into-elementwise patterns greedily to the root op.
+struct LinalgFoldIntoElementwisePass
+    : public impl::LinalgFoldIntoElementwisePassBase<
+          LinalgFoldIntoElementwisePass> {
+  using impl::LinalgFoldIntoElementwisePassBase<
+      LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgFoldIntoElementwisePatterns(patterns);
+
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldTransposePattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
index 7b2ff0b6de12e..e83c32fb6a2cf 100644
--- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
 
 // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>



More information about the Mlir-commits mailing list