[Mlir-commits] [mlir] [mlir][linalg] Add 'compute_element_type' to linalg.elementwise op. (PR #190566)
Javed Absar
llvmlistbot at llvm.org
Sun Apr 5 18:04:56 PDT 2026
https://github.com/javedabsar1 created https://github.com/llvm/llvm-project/pull/190566
Adds optional `compute_element_type` attribute which specifies the element type to be used
for the computation, allowing the operation to be evaluated at a different precision than the
input and result element types. If specified, the input operands are converted to this type
prior to computation (e.g. upcasting from `f16` to `f32`). If the result element type differs
from `compute_element_type`, the computed value is converted back to the result type.
If `compute_element_type` is not provided then input and output types must match.
>From d02bbd8c57ad8882640aeb23dea92eacef3c63cb Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 4 Apr 2026 14:53:28 -0400
Subject: [PATCH] [mlir][linalg] Add 'compute_element_type' to
linalg.elementwise op.
Adds optional `compute_element_type` attribute which specifies the element type
to be used for the computation, allowing the operation to be evaluated at a
different precision than the input and result element types. If specified,
the input operands are converted to `compute_element_type` prior to
computation (e.g. upcasting from `f16` to `f32`). If the result element
type differs from `compute_element_type`, the computed value is converted
back to the result type. If `compute_element_type` is not provided then
input and output types must match.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 23 ++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 149 +++++++++++++++++-
.../elementwise/generalize-named-ops.mlir | 25 +++
.../Dialect/Linalg/elementwise/roundtrip.mlir | 16 ++
4 files changed, 205 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5998f736ced34..542fb4fd01077 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,15 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
operation kind can be unary (e.g. max), binary (e.g. add) or ternary
(e.g. select).
+ The optional `compute_element_type` attribute specifies the element type
+ used for the computation, allowing the operation to be evaluated at a
+ different precision than the input and result element types. If specified,
+ the input operands are converted to `compute_element_type` prior to
+ computation (e.g. upcasting from `f16` to `f32`). If the result element
+  type differs from `compute_element_type`, the computed value is converted
+ back to the result type. If `compute_element_type` is not provided then
+ input and output types must match.
+
By default, all indexing maps are identities. In the case of default
indexing map, all input and output shapes must match. The number of dims in
each of the identity maps is equal to the rank of the output type.
@@ -589,7 +598,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ OptionalAttr<TypeAttr>:$compute_element_type
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -619,12 +629,23 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
/// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
+ /// Return the `compute_element_type` inferred from result type.
+ Type getDefaultComputeElementType();
+
+ /// Return user-defined or default `compute_element_type`.
+ Type getEffectiveComputeElementType();
+
+ /// Returns `true` if inputs and result types do not match
+ /// and therefore the user must provide the compute type.
+ bool needsUserDefinedComputeElementType();
+
/// Both user-specified and default indexing map will always depend on
/// the current Op instance.
static bool hasDynamicIndexingMaps() { return true; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f9c8589683ba7..eeedb6f13f26f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,6 +44,7 @@
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
@@ -4820,6 +4821,51 @@ ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
return SmallVector<AffineMap>(numMaps, map);
}
+/// Returns true if any data operand's element type differs from the output
+/// element type, i.e. the compute type cannot simply be taken from the
+/// result. For ternary ops the leading operand is a control / predicate /
+/// mask value and is not compared.
+bool ElementwiseOp::needsUserDefinedComputeElementType() {
+  Type resultElemType = getElementTypeOrSelf(getOutputs().front().getType());
+  ValueRange inputs = getInputs();
+
+  // Position of the first operand that must share the result element type.
+  size_t firstDataOperand;
+  switch (inputs.size()) {
+  case 1:
+  case 2:
+    firstDataOperand = 0;
+    break;
+  case 3:
+    firstDataOperand = 1;
+    break;
+  default:
+    llvm_unreachable("unknown elementwise arity");
+  }
+
+  for (Value input : inputs.drop_front(firstDataOperand)) {
+    if (getElementTypeOrSelf(input.getType()) != resultElemType)
+      return true;
+  }
+  return false;
+}
+
+/// Element type of the first output operand, used as the compute element type
+/// when none is specified. Returns a null Type if that output is not shaped.
+Type ElementwiseOp::getDefaultComputeElementType() {
+  Type outputType = getOutputs().front().getType();
+  auto shapedOutputType = llvm::dyn_cast<ShapedType>(outputType);
+  return shapedOutputType ? shapedOutputType.getElementType() : Type();
+}
+
+/// Returns the user-provided `compute_element_type` when present, otherwise
+/// the default derived from the output element type.
+Type ElementwiseOp::getEffectiveComputeElementType() {
+  TypeAttr userComputeType = getComputeElementTypeAttr();
+  return userComputeType ? userComputeType.getValue()
+                         : getDefaultComputeElementType();
+}
+
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
// Expect e.g. `kind = #linalg.elemwise_kind<add>`
Attribute attr;
@@ -4840,6 +4886,16 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(
"kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
+ // Parse the optional 'compute_element_type = <type>'.
+ if (succeeded(parser.parseOptionalKeyword("compute_element_type"))) {
+ if (parser.parseEqual())
+ return failure();
+ Type compType;
+ if (parser.parseType(compType))
+ return failure();
+ result.addAttribute("compute_element_type", TypeAttr::get(compType));
+ }
+
// Parse optional `indexing_maps`
SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
@@ -4896,8 +4952,16 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
void ElementwiseOp::print(OpAsmPrinter &p) {
p << " kind=";
p.printAttribute(getKindAttr());
- SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
- "indexing_maps"};
+
+  // Print `compute_element_type` whenever the user supplied one. Eliding it
+  // when inputs and result share an element type would drop the attribute on
+  // round trip and silently change the computation precision (e.g. f16
+  // operands computed in f32), because without it the region builder falls
+  // back to the result element type. Conversely, never fabricate the
+  // attribute when it was not specified.
+  if (TypeAttr computeTypeAttr = getComputeElementTypeAttr())
+    p << " compute_element_type=" << computeTypeAttr.getValue();
+
+  SmallVector<StringRef, 4> elidedAttrs = {
+    "operandSegmentSizes", "kind", "compute_element_type", "indexing_maps"};
+
unsigned arity =
getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
unsigned numDims = getResultRank();
@@ -4914,8 +4978,8 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
elidedAttrs);
}
-/// Implements the block region builder for the ElementwiseOp. This is called by
-/// 'fillStructuredOpRegion'.
+/// Implements the block region builder for the ElementwiseOp.
+/// This is called by 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError) {
@@ -4947,14 +5011,70 @@ void ElementwiseOp::regionBuilder(
SmallVector<Value> yields;
Value result;
+  // Retrieve the user-specified `compute_element_type` or, if absent, infer
+  // it as the element type of the output block argument.
+  Type resultType = block.getArguments().back().getType();
+  Type computeElementType;
+  for (NamedAttribute na : attrs) {
+    if (na.getName() == "compute_element_type") {
+      if (auto typeAttr = llvm::dyn_cast<TypeAttr>(na.getValue())) {
+        computeElementType = typeAttr.getValue();
+      } else {
+        emitError() << "expected TypeAttr for compute_element_type";
+        return;
+      }
+      break;
+    }
+  }
+  if (!computeElementType)
+    computeElementType = resultType; // inferred.
+
+  // Cast `v` to `dstType`, returning `v` unchanged when the types match.
+  // Only same-kind casts are supported (float to float, int to int), and only
+  // when the bit width actually changes: arith.extf/extsi/extui require a
+  // strictly wider result and arith.truncf/trunci a strictly narrower one, so
+  // same-width pairs such as f16 <-> bf16 cannot be expressed with one op.
+  // Unsupported casts report an error and return a null Value, which callers
+  // must check before building further IR.
+  auto castToDstType = [&](Value v, Type dstType) -> Value {
+    Type srcType = v.getType();
+    if (srcType == dstType)
+      return v;
+
+    // Float -> Float
+    auto srcFloatType = dyn_cast<FloatType>(srcType);
+    auto dstFloatType = dyn_cast<FloatType>(dstType);
+    if (srcFloatType && dstFloatType) {
+      if (srcFloatType.getWidth() < dstFloatType.getWidth())
+        return arith::ExtFOp::create(b, b.getLoc(), dstType, v).getResult();
+      if (srcFloatType.getWidth() > dstFloatType.getWidth())
+        return arith::TruncFOp::create(b, b.getLoc(), dstType, v).getResult();
+    }
+
+    // Int -> Int
+    auto srcIntType = dyn_cast<IntegerType>(srcType);
+    auto dstIntType = dyn_cast<IntegerType>(dstType);
+    if (srcIntType && dstIntType) {
+      if (srcIntType.getWidth() < dstIntType.getWidth()) {
+        if (srcIntType.isUnsigned())
+          return arith::ExtUIOp::create(b, b.getLoc(), dstType, v).getResult();
+        return arith::ExtSIOp::create(b, b.getLoc(), dstType, v).getResult();
+      }
+      if (srcIntType.getWidth() > dstIntType.getWidth())
+        return arith::TruncIOp::create(b, b.getLoc(), dstType, v).getResult();
+    }
+
+    emitError() << "invalid cast from " << srcType << " to " << dstType
+                << " in linalg.elementwise";
+    return nullptr;
+  };
+
if (arityGroup == ElementwiseArityGroup::Unary) {
-    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+    Value in0 = castToDstType(block.getArgument(0), computeElementType);
+    if (!in0)
+      return;
+    result = castToDstType(helper.buildUnaryFn(kind.unaryFn, in0), resultType);
+    if (!result)
+      return;
} else if (arityGroup == ElementwiseArityGroup::Binary) {
-    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+    Value in0 = castToDstType(block.getArgument(0), computeElementType);
+    Value in1 = castToDstType(block.getArgument(1), computeElementType);
+    if (!in0 || !in1)
+      return;
+    result = castToDstType(helper.buildBinaryFn(kind.binaryFn, in0, in1),
+                           resultType);
+    if (!result)
+      return;
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
-    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
-                                   block.getArgument(1), block.getArgument(2));
+    // The first operand is a control / predicate / mask value and is passed
+    // through unchanged; only the two data operands are converted.
+    Value in1 = castToDstType(block.getArgument(1), computeElementType);
+    Value in2 = castToDstType(block.getArgument(2), computeElementType);
+    if (!in1 || !in2)
+      return;
+    result = castToDstType(
+        helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), in1, in2),
+        resultType);
+    if (!result)
+      return;
@@ -4983,6 +5103,21 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Checks a user-provided `compute_element_type`; an absent attribute is
+/// always accepted here.
+LogicalResult ElementwiseOp::verify() {
+  TypeAttr computeTypeAttr = getComputeElementTypeAttr();
+  if (computeTypeAttr) {
+    Type computeType = computeTypeAttr.getValue();
+    // Must be a scalar integer or floating-point type.
+    if (!computeType.isIntOrFloat())
+      return emitOpError() << "compute_element_type must be an integer or "
+                              "floating-point type, got "
+                           << computeType;
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..83ff3c6269d58 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,28 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}
+// -----
+// Unary `exp` with an f16 input, an f32 output and `compute_element_type = f32`:
+// the input is extended to f32 before `math.exp`, and the f32 result is
+// yielded directly.
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_upcast_exp(%[[A:.+]]: tensor<8x16x32xf16>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[IN:.+]]: f16, %[[OUT:.+]]: f32)
+// CHECK: %[[EXT:.+]] = arith.extf %[[IN]] : f16 to f32
+// CHECK: %[[EXP:.+]] = math.exp %[[EXT]] : f32
+// CHECK: linalg.yield %[[EXP]] : f32
+//
+func.func @unary_upcast_exp(%A : tensor<8x16x32xf16>, %B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+      kind = #linalg.elementwise_kind<exp>
+      compute_element_type = f32
+      ins(%A : tensor<8x16x32xf16>)
+      outs(%B : tensor<8x16x32xf32>)
+      -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..07a0c35729431 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,19 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
}
+
+// -----
+
+// Round trip of a unary op whose input (f16) and output (f32) element types
+// differ; the printed form is expected to contain `compute_element_type=f32`.
+// CHECK: @unary_input_upcast
+// CHECK: %{{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: compute_element_type=f32
+//
+func.func @unary_input_upcast(%A : tensor<8x16x32xf16>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<exp>
+      compute_element_type=f32
+      ins(%A : tensor<8x16x32xf16>)
+      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
More information about the Mlir-commits
mailing list