[Mlir-commits] [mlir] [mlir][linalg] Add `comp-type` to new elementwise-op. (PR #131542)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 16 15:28:50 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Javed Absar (javedabsar1)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/131542.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+6-2)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+25-6)
- (modified) mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir (+24)
- (modified) mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir (+38)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 308e39a9a51e1..af85daca1c078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
The number of dims of the iterator-types are inferred from the rank of
the result type.
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the result.
+
Example:
Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_kind<exp>
- ins(%x : tensor<4x16x8xf32>)
+ ins(%x : tensor<4x16x8xf16>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..0ffa259023faf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
SmallVector<Value> yields;
Value result;
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
if (arityGroup == ElementwiseArityGroup::Unary) {
- result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+ Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
+ block.getArgument(0));
+ result = helper.buildUnaryFn(kind.unaryFn, val0);
} else if (arityGroup == ElementwiseArityGroup::Binary) {
- result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
- block.getArgument(1));
+ Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
- result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
- block.getArgument(1), block.getArgument(2));
-
+ // select op's select-arg (block arg 0) must remain bool.
+ Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+ block.getArgument(1));
+ Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+ block.getArgument(2));
+ result =
+ helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
} else
assert(false && "found unhandled category in elemwise");
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..19fb0e61d450b 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[MUL]] : f32
+//
+func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..0bce89ca378a4 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
}
+
+// -----
+
+// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
+ %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<div>
+ ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+
+
+// -----
+
+// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: {cast = #linalg.type_fn<cast_signed>}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
+//
+func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
+ %0 = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ {cast = #linalg.type_fn<cast_signed>}
+ ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
+ outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
+ return %0 : tensor<16x8xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131542
More information about the Mlir-commits mailing list