[Mlir-commits] [mlir] 9801a0f - [mlir] Add helper to check elementwise-mappable ops with tensors and scalars (#154872)

llvmlistbot at llvm.org
Mon Aug 25 11:02:00 PDT 2025


Author: Samarth Narang
Date: 2025-08-25T14:01:57-04:00
New Revision: 9801a0f62e49cbd81ce8352ab140af7c240e51ba

URL: https://github.com/llvm/llvm-project/commit/9801a0f62e49cbd81ce8352ab140af7c240e51ba
DIFF: https://github.com/llvm/llvm-project/commit/9801a0f62e49cbd81ce8352ab140af7c240e51ba.diff

LOG: [mlir] Add helper to check elementwise-mappable ops with tensors and scalars (#154872)

This patch introduces a more general helper for identifying
elementwise-mappable operations. The existing utility,
`isElementwiseMappableOpOnRankedTensors`, only accepted operations whose
operands were all ranked tensors. In practice, many elementwise
operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes that restriction: every operand may be either a
ranked tensor or a "scalar-like" type (integer, float, index, or
complex), as long as at least one operand is a ranked tensor.
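
For example (this is the shape exercised by the new tests below), an op
that mixes a scalar with a ranked tensor, such as

    %0 = "test.elementwise_mappable"(%scalar, %tensor)
         : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>

is now converted to a `linalg.generic` in which the scalar operand is
indexed with the empty map `affine_map<(d0, d1) -> ()>` while the tensor
operands and results use the identity map.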

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
    mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..baf4083d15b0c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,26 @@ namespace mlir {
 
 using namespace mlir;
 
+static inline bool isScalarLike(Type t) {
+  return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    return !(isa<RankedTensorType>(t) || isScalarLike(t));
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
       return rewriter.notifyMatchFailure(
           op, "requires elementwise op on ranked tensors");
 
-    auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
-    SmallVector<AffineMap, 3> indexingMaps(
-        op->getNumResults() + op->getNumOperands(),
-        rewriter.getMultiDimIdentityMap(rank));
-    SmallVector<utils::IteratorType, 6> iteratorTypes(
+    auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+    auto rank = resTy.getRank();
+
+    // Maps: identity for tensors (rank > 0), scalar map for scalars.
+    AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+                                         /*results=*/{}, rewriter.getContext());
+    AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+    // Match phase.
+    SmallVector<bool> isScalarOperand;
+    isScalarOperand.reserve(op->getNumOperands());
+    for (Type ty : op->getOperandTypes()) {
+      if (isScalarLike(ty))
+        isScalarOperand.push_back(true);
+      else if (auto rt = dyn_cast<RankedTensorType>(ty))
+        isScalarOperand.push_back(false);
+      else
+        return rewriter.notifyMatchFailure(
+            op,
+            "unsupported operand type (expected scalar-like or ranked tensor)");
+    }
+
+    // Create indexing maps.
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+    for (bool isScalar : isScalarOperand)
+      indexingMaps.push_back(isScalar ? scalarMap : idMap);
+
+    indexingMaps.append(op->getNumResults(), idMap);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
         rank, utils::IteratorType::parallel);
-    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
+    SmallVector<Value> outputs =
+        getOrCreateOperandsMatchingResultTypes(rewriter, op);
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
         op, /*resultTensorTypes=*/op->getResultTypes(),
         /*inputs=*/op->getOperands(),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
         /*iteratorTypes=*/iteratorTypes,
         /*bodyBuilder=*/
         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-          auto resultTypes = llvm::to_vector<6>(
+          SmallVector<Type> resultEltTys = llvm::to_vector<6>(
               llvm::map_range(op->getResultTypes(), [](Type type) {
                 return cast<TensorType>(type).getElementType();
               }));
-          auto *scalarOp =
+          Operation *scalarOp =
               builder.create(loc, op->getName().getIdentifier(),
                              regionArgs.take_front(op->getNumOperands()),
-                             resultTypes, op->getAttrs());
+                             resultEltTys, op->getAttrs());
           linalg::YieldOp::create(builder, loc, scalarOp->getResults());
         });
     return success();

diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index a6552e0a5264e..cc7a5469ba73b 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -108,3 +108,69 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
   return %0 : tensor<4x?x?x8x2x?xi1>
 }
 
+// -----
+
+// Check a mix of scalar and tensor input.
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()> 
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> 
+// CHECK-LABEL: func @scalar_plus_tensor
+func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[GEN:.*]] = linalg.generic
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
+  // CHECK:   "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield {{.*}} : f32
+  // CHECK: } -> tensor<?x?xf32>
+  %0 = "test.elementwise_mappable"(%arg0, %arg1)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+// This test exercises the case where an elementwise op has two scalar-like
+// operands and one ranked tensor operand. In this example, we chain two
+// `test.elementwise_mappable` calls:
+//   %0 = f(%s1, %t)
+//   %1 = f(%s2, %0)
+// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @scalar_tensor_scalar
+func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
+  // First generic.
+  // CHECK: %[[GEN0:.*]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
+  // CHECK:   %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield %[[APPLY0]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+
+  // Second generic.
+  // CHECK: %[[GEN1:.*]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
+  // CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
+  // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
+  // CHECK:   %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
+  // CHECK:   linalg.yield %[[APPLY1]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+  // CHECK: return %[[GEN1]] : tensor<?x?xf32>
+  %0 = "test.elementwise_mappable"(%s1, %t)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "test.elementwise_mappable"(%s2, %0)
+       : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @negative_scalar_only_eltwise
+// CHECK-NOT: linalg
+func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 {
+  %0 = arith.addf %a, %b : f32
+  return %0 : f32
+}
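
For reference, a sketch of the IR this pattern now produces for the
`@scalar_plus_tensor` test above, reconstructed from the CHECK lines
(exact SSA names and attribute ordering may differ):

    #map_scalar = affine_map<(d0, d1) -> ()>
    #map_id     = affine_map<(d0, d1) -> (d0, d1)>
    func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
      // The scalar input uses #map_scalar; the tensor input and the result
      // use the identity map. The tensor operand also serves as the output init.
      %0 = linalg.generic
             {indexing_maps = [#map_scalar, #map_id, #map_id],
              iterator_types = ["parallel", "parallel"]}
             ins(%arg0, %arg1 : f32, tensor<?x?xf32>)
             outs(%arg1 : tensor<?x?xf32>) {
      ^bb0(%s: f32, %t: f32, %out: f32):
        %1 = "test.elementwise_mappable"(%s, %t) : (f32, f32) -> f32
        linalg.yield %1 : f32
      } -> tensor<?x?xf32>
      return %0 : tensor<?x?xf32>
    }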


        

