[Mlir-commits] [mlir] [mlir][linalg] Bugfix for `InlineScalarOperands` (PR #111534)
Longsheng Mou
llvmlistbot at llvm.org
Sat Oct 12 19:21:48 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/111534
>From 739352e4527ff41456ea789d40fe39e8d09cd3bc Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 8 Oct 2024 21:58:55 +0800
Subject: [PATCH] [mlir][linalg] Bugfix for `InlineScalarOperands`
This PR fixes a bug where a `scalarOperand` that is a plain scalar (not a
tensor) was wrongly accessed via `tensor.extract`; such an operand is now used directly.
---
.../Transforms/InlineScalarOperands.cpp | 9 ++++---
.../Linalg/inline-scalar-operands.mlir | 24 +++++++++++++++++++
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d112..2a1445fb92fdc6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,12 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
for (auto idx : indices)
indicesValues.emplace_back(
rewriter.create<arith::ConstantIndexOp>(loc, idx));
- Value extractedValue = rewriter.create<tensor::ExtractOp>(
- loc, opOperand->get(), indicesValues);
- body->getArgument(idx).replaceAllUsesWith(extractedValue);
+ Value scalarValue = opOperand->get();
+ if (isa<RankedTensorType>(scalarValue.getType())) {
+ scalarValue =
+ rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+ }
+ body->getArgument(idx).replaceAllUsesWith(scalarValue);
body->eraseArgument(idx);
}
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c7461..8384b307d2dfbd 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>  // identity map: tensor<4xf32> input and output
+#map3 = affine_map<(d0) -> ()>  // zero-result map: f32 scalar operand
+// Regression test: a plain f32 (non-tensor) scalar operand is removed from ins and its block argument is replaced by %scalar itself.
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
+ %0 = tensor.empty() : tensor<4xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+ // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+ %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %scalar : tensor<4xf32>, f32)
+ outs(%0 : tensor<4xf32>) {
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // %arg2 is the block argument of the f32 scalar operand
+ // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+ %2 = arith.divf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
More information about the Mlir-commits
mailing list