[Mlir-commits] [mlir] [mlir][linalg] Bugfix for `InlineScalarOperands` (PR #111534)

Longsheng Mou llvmlistbot at llvm.org
Sat Oct 12 19:21:48 PDT 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/111534

>From 739352e4527ff41456ea789d40fe39e8d09cd3bc Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 8 Oct 2024 21:58:55 +0800
Subject: [PATCH] [mlir][linalg] Bugfix for `InlineScalarOperands`

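This PR fixes a bug where the pattern unconditionally created a `tensor.extract`
for every operand being inlined. When that operand is a plain scalar (e.g. `f32`)
rather than a ranked tensor, it is now used directly instead of being accessed
via `tensor.extract`.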
---
 .../Transforms/InlineScalarOperands.cpp       |  9 ++++---
 .../Linalg/inline-scalar-operands.mlir        | 24 +++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d112..2a1445fb92fdc6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,12 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       for (auto idx : indices)
         indicesValues.emplace_back(
             rewriter.create<arith::ConstantIndexOp>(loc, idx));
-      Value extractedValue = rewriter.create<tensor::ExtractOp>(
-          loc, opOperand->get(), indicesValues);
-      body->getArgument(idx).replaceAllUsesWith(extractedValue);
+      Value scalarValue = opOperand->get();
+      if (isa<RankedTensorType>(scalarValue.getType())) {
+        scalarValue =
+            rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+      }
+      body->getArgument(idx).replaceAllUsesWith(scalarValue);
       body->eraseArgument(idx);
     }
 
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c7461..8384b307d2dfbd 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
     } -> tensor<4xf32>
   return %1 : tensor<4xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>  // identity map, used for %arg0 and the output
+#map3 = affine_map<(d0) -> ()>  // zero-result (constant) map, used for the scalar input
+
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {  // regression test: non-tensor scalar input
+    %0 = tensor.empty() : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, f32)  // %scalar is a plain f32, not a tensor
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // %arg2 is the block argument bound to %scalar
+      // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+      %2 = arith.divf %arg1, %arg2 : f32  // expected: %arg2 replaced by %scalar itself, with no tensor.extract
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}



More information about the Mlir-commits mailing list