[Mlir-commits] [mlir] [mlir] Add helper to check elementwise-mappable ops with tensors and scalars (PR #154872)

Samarth Narang llvmlistbot at llvm.org
Mon Aug 25 10:33:37 PDT 2025


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/154872

>From 23a203c309f9cc1c270255eabc24433acfbc2e78 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Thu, 21 Aug 2025 21:56:25 -0400
Subject: [PATCH 1/6] Fix TODO to use any_of instead of all_of; make check more
 adaptive to include broadcasting of scalars

---
 .../Linalg/Transforms/ElementwiseToLinalg.cpp | 30 +++++++++++++++++--
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  if (auto rt = dyn_cast<RankedTensorType>(t))
+    return rt.getRank() == 0; // 0-D tensors are scalar-like
+  return false;
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(
+      types, [](Type type) { return isa<RankedTensorType>(type); });
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    // Invalid if neither ranked tensor nor scalar-like.
+    if (llvm::isa<RankedTensorType>(t))
+      return false;
+    if (isScalarLike(t))
+      return false;
+    return true; // Could be a memref, unranked tensor, vector, etc.
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

>From 98850933582261895ddebb409b9d7a4a12843697 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Fri, 22 Aug 2025 09:46:38 -0400
Subject: [PATCH 2/6] Add tests

---
 .../Linalg/convert-elementwise-to-linalg.mlir | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index a6552e0a5264e..ae574b7905be7 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -19,6 +19,53 @@ func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
+// Test a binary elementwise op with a tensor and a scalar operand.
+// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
+//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
+func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %init = tensor.empty(%d0) : tensor<?xf32>
+  %splat = linalg.fill ins(%s : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iterator_types = ["parallel"]
+  // CHECK-SAME: ins(%[[T]], %{{.*}}
+  %0 = arith.addf %t, %splat : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// Test a comparison op between a tensor and a scalar.
+// CHECK-LABEL: func @cmpf_tensor_scalar
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
+func.func @cmpf_tensor_scalar(%a: tensor<?xf32>, %s: f32) -> tensor<?xi1> {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %a, %c0 : tensor<?xf32>
+  %initS = tensor.empty(%d0) : tensor<?xf32>
+  %splat = linalg.fill ins(%s : f32) outs(%initS : tensor<?xf32>) -> tensor<?xf32>
+
+  %init = tensor.empty(%d0) : tensor<?xi1>
+  // CHECK: %[[INIT:.*]] = tensor.empty
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins(%[[A]], %{{.*}}
+  %0 = arith.cmpf olt, %a, %splat : tensor<?xf32>
+  return %0 : tensor<?xi1>
+}
+
+// Test a binary elementwise op with a tensor and a zero-dimensional
+// (rank-0) tensor.
+// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor
+//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor<f32>
+func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor<f32>) -> tensor<4xf32> {
+  %c = tensor.extract %r0[] : tensor<f32>
+  %init = tensor.empty() : tensor<4xf32>
+  %splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins(%[[T]], %{{.*}}
+  %0 = arith.addf %t, %splat : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+
 // -----
 
 // Check indexing maps and iterator types for the rank > 0 case.

>From 48bb04ec175208544b104cfb2faf3b1b324c139b Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Sat, 23 Aug 2025 07:58:48 -0400
Subject: [PATCH 3/6] Classifies scalar-like operands and assigns them a
 rank-aware scalar map (d0,…,dn) -> () during lowering.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Linalg/Transforms/ElementwiseToLinalg.cpp | 44 +++++++++++++++----
 1 file changed, 35 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 87e6ff2fa13c6..2cdbf692e0309 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -105,13 +105,39 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
       return rewriter.notifyMatchFailure(
           op, "requires elementwise op on ranked tensors");
 
-    auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
-    SmallVector<AffineMap, 3> indexingMaps(
-        op->getNumResults() + op->getNumOperands(),
-        rewriter.getMultiDimIdentityMap(rank));
-    SmallVector<utils::IteratorType, 6> iteratorTypes(
+    auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+    auto rank = resTy.getRank();
+
+    // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0.
+    AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+                                         /*results=*/{}, rewriter.getContext());
+    AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+    // Create indexing maps: one per operand, one per result.
+    SmallVector<AffineMap, 6> indexingMaps;
+    indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+    for (Value v : op->getOperands()) {
+      Type ty = v.getType();
+      if (isScalarLike(ty))
+        indexingMaps.push_back(scalarMap);
+      else if (auto rt = dyn_cast<RankedTensorType>(ty)) {
+        indexingMaps.push_back(idMap);
+      } else
+        return rewriter.notifyMatchFailure(
+            op,
+            "unsupported operand type (expected scalar-like or ranked tensor)");
+    }
+
+    for (Value r : op->getResults()) {
+      (void)r;
+      indexingMaps.push_back(idMap); // results use identity map.
+    }
+
+    SmallVector<utils::IteratorType, 4> iteratorTypes(
         rank, utils::IteratorType::parallel);
-    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
+    SmallVector<Value, 2> outputs =
+        getOrCreateOperandsMatchingResultTypes(rewriter, op);
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
         op, /*resultTensorTypes=*/op->getResultTypes(),
         /*inputs=*/op->getOperands(),
@@ -120,14 +146,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
         /*iteratorTypes=*/iteratorTypes,
         /*bodyBuilder=*/
         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-          auto resultTypes = llvm::to_vector<6>(
+          SmallVector<Type> resultEltTys = llvm::to_vector<6>(
               llvm::map_range(op->getResultTypes(), [](Type type) {
                 return cast<TensorType>(type).getElementType();
               }));
-          auto *scalarOp =
+          Operation *scalarOp =
               builder.create(loc, op->getName().getIdentifier(),
                              regionArgs.take_front(op->getNumOperands()),
-                             resultTypes, op->getAttrs());
+                             resultEltTys, op->getAttrs());
           linalg::YieldOp::create(builder, loc, scalarOp->getResults());
         });
     return success();

>From 636367607c101994c114c7dceff1f83ad703f3ae Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Sat, 23 Aug 2025 07:59:07 -0400
Subject: [PATCH 4/6] Replace splat-based tests with direct scalar-operand tests

---
 .../Linalg/convert-elementwise-to-linalg.mlir | 105 ++++++++++--------
 1 file changed, 58 insertions(+), 47 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index ae574b7905be7..7aa925ef80517 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -19,53 +19,6 @@ func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
-// Test a binary elementwise op with a tensor and a scalar operand.
-// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
-//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
-func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> {
-  %c0 = arith.constant 0 : index
-  %d0 = tensor.dim %t, %c0 : tensor<?xf32>
-  %init = tensor.empty(%d0) : tensor<?xf32>
-  %splat = linalg.fill ins(%s : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: linalg.generic
-  // CHECK-SAME: iterator_types = ["parallel"]
-  // CHECK-SAME: ins(%[[T]], %{{.*}}
-  %0 = arith.addf %t, %splat : tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
-
-// Test a comparison op between a tensor and a scalar.
-// CHECK-LABEL: func @cmpf_tensor_scalar
-//  CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
-func.func @cmpf_tensor_scalar(%a: tensor<?xf32>, %s: f32) -> tensor<?xi1> {
-  %c0 = arith.constant 0 : index
-  %d0 = tensor.dim %a, %c0 : tensor<?xf32>
-  %initS = tensor.empty(%d0) : tensor<?xf32>
-  %splat = linalg.fill ins(%s : f32) outs(%initS : tensor<?xf32>) -> tensor<?xf32>
-
-  %init = tensor.empty(%d0) : tensor<?xi1>
-  // CHECK: %[[INIT:.*]] = tensor.empty
-  // CHECK: linalg.generic
-  // CHECK-SAME: ins(%[[A]], %{{.*}}
-  %0 = arith.cmpf olt, %a, %splat : tensor<?xf32>
-  return %0 : tensor<?xi1>
-}
-
-// Test a binary elementwise op with a tensor and a zero-dimensional
-// (rank-0) tensor.
-// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor
-//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor<f32>
-func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor<f32>) -> tensor<4xf32> {
-  %c = tensor.extract %r0[] : tensor<f32>
-  %init = tensor.empty() : tensor<4xf32>
-  %splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: linalg.generic
-  // CHECK-SAME: ins(%[[T]], %{{.*}}
-  %0 = arith.addf %t, %splat : tensor<4xf32>
-  return %0 : tensor<4xf32>
-}
-
-
 // -----
 
 // Check indexing maps and iterator types for the rank > 0 case.
@@ -155,3 +108,61 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
   return %0 : tensor<4x?x?x8x2x?xi1>
 }
 
+// -----
+
+// Check a mix of scalar and tensor input.
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()> 
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> 
+// CHECK-LABEL: func @scalar_plus_tensor
+// CHECK: %[[GEN:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
+// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
+// CHECK:   "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
+// CHECK:   linalg.yield {{.*}} : f32
+// CHECK: } -> tensor<?x?xf32>
+func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = "test.elementwise_mappable"(%arg0, %arg1)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+// This test exercises the case where an elementwise op has two scalar-like
+// operands and one ranked tensor operand. In this example, we chain two
+// `test.elementwise_mappable` calls:
+//   %0 = f(%s1, %t)
+//   %1 = f(%s2, %0)
+// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @scalar_tensor_scalar
+// First generic.
+// CHECK: %[[GEN0:.*]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
+// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
+// CHECK:   %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
+// CHECK:   linalg.yield %[[APPLY0]] : f32
+// CHECK: } -> tensor<?x?xf32>
+
+// Second generic.
+// CHECK: %[[GEN1:.*]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
+// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
+// CHECK:   %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
+// CHECK:   linalg.yield %[[APPLY1]] : f32
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: return %[[GEN1]] : tensor<?x?xf32>
+func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
+  %0 = "test.elementwise_mappable"(%s1, %t)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "test.elementwise_mappable"(%s2, %0)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}

>From 9009a132bf9687fd764e86f557046db4809451ce Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Mon, 25 Aug 2025 12:10:59 -0400
Subject: [PATCH 5/6] Address review comments

---
 .../Linalg/Transforms/ElementwiseToLinalg.cpp | 53 +++++++---------
 .../Linalg/convert-elementwise-to-linalg.mlir | 60 +++++++++----------
 2 files changed, 52 insertions(+), 61 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 2cdbf692e0309..baf4083d15b0c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,8 @@ namespace mlir {
 
 using namespace mlir;
 
-// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
 static inline bool isScalarLike(Type t) {
-  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
-    return true;
-  if (auto rt = dyn_cast<RankedTensorType>(t))
-    return rt.getRank() == 0; // 0-D tensors are scalar-like
-  return false;
+  return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
 }
 
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
@@ -36,18 +31,12 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   auto types = op->getOperandTypes();
 
   // We want at least one ranked tensor.
-  bool anyRankedTensor = llvm::any_of(
-      types, [](Type type) { return isa<RankedTensorType>(type); });
+  bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
 
   // No invalid operands (i.e., every operand is a ranked tensor or
   // scalar-like).
   bool noneInvalid = llvm::none_of(types, [](Type t) {
-    // Invalid if neither ranked tensor nor scalar-like.
-    if (llvm::isa<RankedTensorType>(t))
-      return false;
-    if (isScalarLike(t))
-      return false;
-    return true; // Could be a memref, unranked tensor, vector, etc.
+    return !(isa<RankedTensorType>(t) || isScalarLike(t));
   });
 
   return anyRankedTensor && noneInvalid;
@@ -108,35 +97,37 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
     auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
     auto rank = resTy.getRank();
 
-    // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0.
+    // Maps: identity for ranked tensors (any rank), broadcast map for scalars.
     AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
                                          /*results=*/{}, rewriter.getContext());
     AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
 
-    // Create indexing maps: one per operand, one per result.
-    SmallVector<AffineMap, 6> indexingMaps;
-    indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
-
-    for (Value v : op->getOperands()) {
-      Type ty = v.getType();
+    // Match phase: classify every operand before any IR is created.
+    SmallVector<bool> isScalarOperand;
+    isScalarOperand.reserve(op->getNumOperands());
+    for (Type ty : op->getOperandTypes()) {
       if (isScalarLike(ty))
-        indexingMaps.push_back(scalarMap);
-      else if (auto rt = dyn_cast<RankedTensorType>(ty)) {
-        indexingMaps.push_back(idMap);
-      } else
+        isScalarOperand.push_back(true);
+      else if (isa<RankedTensorType>(ty))
+        isScalarOperand.push_back(false);
+      else
         return rewriter.notifyMatchFailure(
             op,
             "unsupported operand type (expected scalar-like or ranked tensor)");
     }
 
-    for (Value r : op->getResults()) {
-      (void)r;
-      indexingMaps.push_back(idMap); // results use identity map.
-    }
+    // Create indexing maps.
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+    for (bool isScalar : isScalarOperand)
+      indexingMaps.push_back(isScalar ? scalarMap : idMap);
+
+    indexingMaps.append(op->getNumResults(), idMap);
 
-    SmallVector<utils::IteratorType, 4> iteratorTypes(
+    SmallVector<utils::IteratorType> iteratorTypes(
         rank, utils::IteratorType::parallel);
-    SmallVector<Value, 2> outputs =
+    SmallVector<Value> outputs =
         getOrCreateOperandsMatchingResultTypes(rewriter, op);
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
         op, /*resultTensorTypes=*/op->getResultTypes(),
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index 7aa925ef80517..a01efb3d6c32e 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -114,15 +114,15 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()> 
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> 
 // CHECK-LABEL: func @scalar_plus_tensor
-// CHECK: %[[GEN:.*]] = linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
-// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
-// CHECK:   "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
-// CHECK:   linalg.yield {{.*}} : f32
-// CHECK: } -> tensor<?x?xf32>
 func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[GEN:.*]] = linalg.generic
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
+  // CHECK:   "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield {{.*}} : f32
+  // CHECK: } -> tensor<?x?xf32>
   %0 = "test.elementwise_mappable"(%arg0, %arg1)
        : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -137,29 +137,29 @@ func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?x
 // CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
 // CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @scalar_tensor_scalar
-// First generic.
-// CHECK: %[[GEN0:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
-// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
-// CHECK:   %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
-// CHECK:   linalg.yield %[[APPLY0]] : f32
-// CHECK: } -> tensor<?x?xf32>
-
-// Second generic.
-// CHECK: %[[GEN1:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
-// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
-// CHECK:   %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
-// CHECK:   linalg.yield %[[APPLY1]] : f32
-// CHECK: } -> tensor<?x?xf32>
-// CHECK: return %[[GEN1]] : tensor<?x?xf32>
 func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
+  // First generic.
+  // CHECK: %[[GEN0:.*]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
+  // CHECK:   %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield %[[APPLY0]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+
+  // Second generic.
+  // CHECK: %[[GEN1:.*]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
+  // CHECK:   %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield %[[APPLY1]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+  // CHECK: return %[[GEN1]] : tensor<?x?xf32>
   %0 = "test.elementwise_mappable"(%s1, %t)
        : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
   %1 = "test.elementwise_mappable"(%s2, %0)

>From d3793695e85026f7f756cfc998563b69283ad03e Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Mon, 25 Aug 2025 13:33:06 -0400
Subject: [PATCH 6/6] Add negative test

---
 .../Dialect/Linalg/convert-elementwise-to-linalg.mlir     | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index a01efb3d6c32e..cc7a5469ba73b 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -166,3 +166,12 @@ func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tens
        : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
+
+// -----
+// An elementwise op whose operands are all scalars must not be converted.
+// CHECK-LABEL: func @negative_scalar_only_eltwise
+// CHECK-NOT: linalg
+func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 {
+  %0 = arith.addf %a, %b : f32
+  return %0 : f32
+}



More information about the Mlir-commits mailing list