[Mlir-commits] [mlir] [mlir] Add helper to check elementwise-mappable ops with tensors and scalars (PR #154872)

Samarth Narang llvmlistbot at llvm.org
Fri Aug 22 06:46:59 PDT 2025


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/154872

>From 8a18d012cdef703057a1032f3d2ea706c52f7987 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Thu, 21 Aug 2025 21:56:25 -0400
Subject: [PATCH 1/2] Fix TODO to use any_of instead of all_of. Make the check
 accept scalar operands alongside ranked tensors

---
 .../Linalg/Transforms/ElementwiseToLinalg.cpp | 30 +++++++++++++++++--
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Returns true for builtin scalar types (integer, float, index, complex) and
+/// for rank-0 ranked tensors. This only classifies `t`; nothing is broadcast.
+static bool isScalarLike(Type t) {
+  if (isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  auto rankedTy = dyn_cast<RankedTensorType>(t);
+  return rankedTy && rankedTy.getRank() == 0;
+}
+
+/// Returns true if `op` has elementwise-mappable traits, at least one operand
+/// is a ranked tensor, and every operand is either a ranked tensor or
+/// scalar-like (see `isScalarLike`). An op with no operands is rejected
+/// because `any_of` over an empty range is false.
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;

-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  // TODO: Admitting scalar operands is only sound if the conversion pattern
+  // tracks which operands are scalars (e.g. gives them a non-identity indexing
+  // map and never derives the result shape from them). NOTE(review): confirm
+  // the pattern does this; otherwise keep requiring all ranked tensors.
+  auto operandTypes = op->getOperandTypes();
+
+  // At least one operand must be a ranked tensor.
+  bool hasRankedTensorOperand =
+      llvm::any_of(operandTypes, llvm::IsaPred<RankedTensorType>);
+  // Every operand must be a ranked tensor or scalar-like.
+  bool allOperandsSupported = llvm::all_of(operandTypes, [](Type t) {
+    return isa<RankedTensorType>(t) || isScalarLike(t);
+  });
+  return hasRankedTensorOperand && allOperandsSupported;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

>From fe6a39068c5dc38b73280e1f3be17796511d991a Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at umass.edu>
Date: Fri, 22 Aug 2025 09:46:38 -0400
Subject: [PATCH 2/2] Add tests

---
 .../Linalg/convert-elementwise-to-linalg.mlir | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index a6552e0a5264e..ae574b7905be7 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -19,6 +19,53 @@ func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
+// addf on tensor<?xf32> operands; one is a linalg.fill splat of a scalar.
+// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
+//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
+func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %init = tensor.empty(%d0) : tensor<?xf32>
+  %splat = linalg.fill ins(%s : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: %[[SPLAT:.*]] = linalg.fill ins(%[[S]] : f32)
+  // CHECK: linalg.generic
+  // CHECK-SAME: iterator_types = ["parallel"]
+  // CHECK-SAME: ins(%[[T]], %[[SPLAT]]
+  %0 = arith.addf %t, %splat : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// cmpf on tensor<?xf32> operands; one is a linalg.fill splat of a scalar.
+// CHECK-LABEL: func @cmpf_tensor_scalar
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
+func.func @cmpf_tensor_scalar(%a: tensor<?xf32>, %s: f32) -> tensor<?xi1> {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %a, %c0 : tensor<?xf32>
+  %initS = tensor.empty(%d0) : tensor<?xf32>
+  %splat = linalg.fill ins(%s : f32) outs(%initS : tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: %[[SPLAT:.*]] = linalg.fill ins(%[[S]] : f32)
+  // CHECK: %[[INIT:.*]] = tensor.empty({{.*}}) : tensor<?xi1>
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins(%[[A]], %[[SPLAT]] : tensor<?xf32>, tensor<?xf32>) outs(%[[INIT]] : tensor<?xi1>)
+  %0 = arith.cmpf olt, %a, %splat : tensor<?xf32>
+  return %0 : tensor<?xi1>
+}
+
+// addf on tensor<4xf32> operands; one is a splat of a rank-0 tensor's element.
+// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor
+//  CHECK-SAME:   %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor<f32>
+func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor<f32>) -> tensor<4xf32> {
+  %c = tensor.extract %r0[] : tensor<f32>
+  %init = tensor.empty() : tensor<4xf32>
+  %splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %[[ELEM:.*]] = tensor.extract %[[R0]][]
+  // CHECK: %[[SPLAT:.*]] = linalg.fill ins(%[[ELEM]] : f32)
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins(%[[T]], %[[SPLAT]]
+  %0 = arith.addf %t, %splat : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
 // -----
 
 // Check indexing maps and iterator types for the rank > 0 case.



More information about the Mlir-commits mailing list