[Mlir-commits] [mlir] [mlir][linalg] Add `comp-type` to new elementwise-op. (PR #131542)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 16 15:28:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/131542.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+6-2) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+25-6) 
- (modified) mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir (+24) 
- (modified) mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir (+38) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 308e39a9a51e1..af85daca1c078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     The number of dims of the iterator-types are inferred from the rank of
     the result type.
 
+    Numeric casting is performed on the input operands, promoting them to the
+    same data type as the result.
+
     Example:
 
     Defining a unary linalg.elemwise with default indexing-map:
     ```mlir
     %exp = linalg.elemwise
         kind=#linalg.elemwise_kind<exp>
-        ins(%x : tensor<4x16x8xf32>)
+        ins(%x : tensor<4x16x8xf16>)
         outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
     ```
 
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
       ElementwiseKindAttr:$kind,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
 
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..0ffa259023faf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   SmallVector<Value> yields;
   Value result;
 
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   if (arityGroup == ElementwiseArityGroup::Unary) {
-    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
+                                    block.getArgument(0));
+    result = helper.buildUnaryFn(kind.unaryFn, val0);
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
-    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+    result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
-    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
-                                   block.getArgument(1), block.getArgument(2));
-
+    // select op's select-arg (block arg 0) must remain bool.
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(1));
+    Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(2));
+    result =
+        helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
   } else
     assert(false && "found unhandled category in elemwise");
 
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..19fb0e61d450b 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
       outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
   return %r : tensor<8x16x32xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK:   %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[MUL]] : f32
+//
+func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<mul>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..0bce89ca378a4 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
       outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
   return %r : tensor<1x2x3x4x5xi32>
 }
+
+// -----
+
+// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME:                    %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<div>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
+                      %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<div>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+
+
+// -----
+
+// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
+// CHECK-SAME:           %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
+// CHECK:  {{.*}} = linalg.elementwise
+// CHECK-SAME:          kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:          {cast = #linalg.type_fn<cast_signed>}
+// CHECK-SAME:          ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
+// CHECK-SAME:          outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
+//
+func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
+    %0 = linalg.elementwise
+            kind=#linalg.elementwise_kind<add>
+            {cast = #linalg.type_fn<cast_signed>}
+            ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
+            outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
+    return %0 : tensor<16x8xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/131542


More information about the Mlir-commits mailing list