[Mlir-commits] [mlir] 4b31568 - [mlir][linalg] Bugfix for `InlineScalarOperands` (#111534)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 14 00:38:39 PDT 2024
Author: Longsheng Mou
Date: 2024-10-14T15:38:35+08:00
New Revision: 4b31568e026c844cc577954b050e0f5a7d96bc0c
URL: https://github.com/llvm/llvm-project/commit/4b31568e026c844cc577954b050e0f5a7d96bc0c
DIFF: https://github.com/llvm/llvm-project/commit/4b31568e026c844cc577954b050e0f5a7d96bc0c.diff
LOG: [mlir][linalg] Bugfix for `InlineScalarOperands` (#111534)
This PR fixes a bug where a `scalarOperand` that is already a plain scalar
(not a tensor) was incorrectly accessed via `tensor.extract`; such an
operand must be used directly instead. Fixes #111243.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d112..2a1445fb92fdc6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,12 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
for (auto idx : indices)
indicesValues.emplace_back(
rewriter.create<arith::ConstantIndexOp>(loc, idx));
- Value extractedValue = rewriter.create<tensor::ExtractOp>(
- loc, opOperand->get(), indicesValues);
- body->getArgument(idx).replaceAllUsesWith(extractedValue);
+ Value scalarValue = opOperand->get();
+ if (isa<RankedTensorType>(scalarValue.getType())) {
+ scalarValue =
+ rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+ }
+ body->getArgument(idx).replaceAllUsesWith(scalarValue);
body->eraseArgument(idx);
}
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c7461..8384b307d2dfbd 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
+ %0 = tensor.empty() : tensor<4xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+ // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+ %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %scalar : tensor<4xf32>, f32)
+ outs(%0 : tensor<4xf32>) {
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+ %2 = arith.divf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
More information about the Mlir-commits
mailing list