[Mlir-commits] [mlir] ecf4d99 - [mlir][linalg][elementwise] Fold transpose into new elementwise (#130207)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 12 16:04:48 PDT 2025
Author: Javed Absar
Date: 2025-03-12T23:04:44Z
New Revision: ecf4d995f689a30bb5a2b79c27998a7e7a0a08b0
URL: https://github.com/llvm/llvm-project/commit/ecf4d995f689a30bb5a2b79c27998a7e7a0a08b0
DIFF: https://github.com/llvm/llvm-project/commit/ecf4d995f689a30bb5a2b79c27998a7e7a0a08b0.diff
LOG: [mlir][linalg][elementwise] Fold transpose into new elementwise (#130207)
Fold `linalg.transpose` into the new `linalg.elementwise` op, which carries explicit affine indexing maps.
Broadcast folding will be added in a follow-up diff.
Added:
mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
mlir/test/Dialect/Linalg/elementwise/fold.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..308e39a9a51e1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,7 +601,18 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElementwiseOp::getRegionBuilder());
- }]>
+ }]>,
+
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ElementwiseKindAttr":$kind,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("kind", kind);
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElementwiseOp::getRegionBuilder());
+ }]>
];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..373842c9b03de 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
+  let summary = "Fold transpose, broadcast and other ops into elementwise";
+  let description = [{
+    Folds producer ops into the indexing maps of `linalg.elementwise` ops.
+    Currently only `linalg.transpose` producers feeding identity-mapped
+    inputs are folded.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8fdcdeff250bb..c302f6d682d69 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that fold producer operations such as
+/// `linalg.transpose` into the indexing maps of `linalg.elementwise` ops.
+void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d18b6f8afc43b..881d9fcb4f52e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
+ FoldIntoElementwise.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
new file mode 100644
index 0000000000000..bdd4f6025b051
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -0,0 +1,89 @@
+//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements folding ops such as transpose and broadcast into the
+// affine maps of the elementwise op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-fold-into-elementwise"
+
+namespace {
+/// Folds a transpose producer into an identity-mapped elementwise input map.
+struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> foldedInputs;
+    SmallVector<AffineMap> foldedMaps;
+    bool foldedAny = false;
+    for (OpOperand *input : op.getDpsInputOperands()) {
+      AffineMap inputMap = op.getMatchingIndexingMap(input);
+      auto producer = input->get().getDefiningOp<TransposeOp>();
+      if (producer && inputMap.isIdentity()) {
+        // Read the transpose source via its inverse-permutation input map.
+        foldedInputs.push_back(producer.getInput());
+        foldedMaps.push_back(
+            producer.getMatchingIndexingMap(producer.getDpsInputOperand(0)));
+        foldedAny = true;
+      } else {
+        // Not foldable: keep the operand and its map unchanged.
+        foldedInputs.push_back(input->get());
+        foldedMaps.push_back(inputMap);
+      }
+    }
+    if (!foldedAny)
+      return failure();
+    // Keep the original map of the last (init) operand unchanged.
+    foldedMaps.push_back(op.getIndexingMapsArray().back());
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, foldedInputs, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(foldedMaps));
+    return success();
+  }
+};
+
+/// Greedily applies the fold-into-elementwise patterns to the anchor op.
+struct LinalgFoldIntoElementwisePass
+    : public impl::LinalgFoldIntoElementwisePassBase<
+          LinalgFoldIntoElementwisePass> {
+  using impl::LinalgFoldIntoElementwisePassBase<
+      LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgFoldIntoElementwisePatterns(patterns);
+    // The greedy driver reports failure when it does not converge.
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
+    RewritePatternSet &patterns) {
+  // Only transpose folding exists so far; broadcast folding is planned.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FoldTransposePattern>(context);
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..e83c32fb6a2cf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+// The only input of a unary elementwise op is a transpose result.
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %init = tensor.empty() : tensor<8x16x32xf32>
+  %a_t = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%init : tensor<8x16x32xf32>) permutation = [1, 0, 2]
+  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+      ins(%a_t : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %exp : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  // The transposed %B is added elementwise to %A, so the transpose init must
+  // have %A's shape (dim0 x dim1); %B is presumably (dim1 x dim0).
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list