[Mlir-commits] [mlir] TosaToLinalg: Prefer to emit identity maps (#386) (PR #123295)
Matthias Gehre
llvmlistbot at llvm.org
Thu Jan 16 23:45:47 PST 2025
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/123295
When deciding whether to emit a map like
`#map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>` or `#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` for an operand of a `linalg.generic` when lowering element-wise TOSA ops, prefer the latter unless broadcasting of the operand is really needed.
This helps later transformations, which often require the affine map to be a projected permutation.
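The per-dimension decision described above can be sketched outside of MLIR as follows. This is only an illustrative model, not the code in `TosaToLinalg.cpp`: the helper name `operandMapResults` is hypothetical, shapes are assumed to be static and of equal rank, and the function returns the textual result list of the affine map instead of building a real `AffineMap`.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the per-operand map construction. It returns the
// result list of the affine map as text, e.g. "(0, d1)" or "(d0, d1)".
// Assumes operandShape and resultShape are static and have the same rank.
std::string operandMapResults(const std::vector<int64_t> &operandShape,
                              const std::vector<int64_t> &resultShape) {
  std::string out = "(";
  for (std::size_t i = 0; i < operandShape.size(); ++i) {
    // Broadcasting is needed only when the operand dimension is 1 and the
    // corresponding result dimension is not 1.
    bool requiresBroadcast = operandShape[i] == 1 && resultShape[i] != 1;
    if (i != 0)
      out += ", ";
    // A broadcast dimension maps to the constant 0; every other dimension
    // maps to its own loop index, which keeps the map an identity there.
    out += requiresBroadcast ? std::string("0") : "d" + std::to_string(i);
  }
  return out + ")";
}
```

Under these assumptions, an operand of shape `1x4` feeding a `3x4` result gets `(0, d1)`, while the same operand feeding a `1x4` result gets the identity `(d0, d1)`.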
From 2bc038ef83edf7531b3e51d7877edc2c0a806ea6 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 18 Oct 2024 08:05:58 -0700
Subject: [PATCH] TosaToLinalg: Prefer to emit identity maps (#386)
When deciding whether to emit a map like
`#map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>`
or `#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
for an operand of a `linalg.generic` when lowering element-wise TOSA ops,
prefer the latter unless broadcasting of the operand is really needed.
This helps later transformations, which often require the affine map to be
a projected permutation.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 10 ++++++++--
.../TosaToLinalg/tosa-to-linalg.mlir | 20 +++++++++++++++++++
2 files changed, 28 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 9295afd36e3ab1..a183c27abf62ae 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -882,8 +882,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
auto shape = cast<ShapedType>(operand.getType()).getShape();
SmallVector<AffineExpr> affineExprs;
for (auto it : llvm::enumerate(shape)) {
- auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
- : rewriter.getAffineDimExpr(it.index());
+ // Prefer producing identity maps whenever possible (i.e. no broadcasting
+ // needed) because some transforms (like reshape folding)
+ // do not support affine constant exprs.
+ bool requiresBroadcast =
+ (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
+ auto affineExpr = requiresBroadcast
+ ? rewriter.getAffineConstantExpr(0)
+ : rewriter.getAffineDimExpr(it.index());
affineExprs.push_back(affineExpr);
}
return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1d235092b71d55..f36f449da8dbc4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -253,6 +253,26 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t
// -----
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_matching_no_broadcast
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_matching_no_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+
+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32>
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_0]] : tensor<1xf32>) {
+ // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+ // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+ // CHECK: linalg.yield %[[VAL_4]] : f32
+ // CHECK: } -> tensor<1xf32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: return %[[RESULT]] : tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_matching_static
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: