[Mlir-commits] [mlir] [mlir][linalg] Align `elementwise` builder to do type-conversion of input to result type (PR #190566)

Javed Absar llvmlistbot at llvm.org
Fri Apr 17 12:52:30 PDT 2026


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/190566

>From d02bbd8c57ad8882640aeb23dea92eacef3c63cb Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 4 Apr 2026 14:53:28 -0400
Subject: [PATCH 1/2] [mlir][linalg] Add 'compute_element_type' to
 linalg.elementwise op.

Adds optional `compute_element_type` attribute which specifies the element type
to be used for the computation, allowing the operation to be evaluated at a
different precision than the input and result element types. If specified,
the input operands are converted to `compute_element_type` prior to
computation (e.g. upcasting from `f16` to `f32`). If the result element
type differs from `compute_element_type`, the computed value is converted
back to the result type. If `compute_element_type` is not provided then
input and output types must match.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  23 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 149 +++++++++++++++++-
 .../elementwise/generalize-named-ops.mlir     |  25 +++
 .../Dialect/Linalg/elementwise/roundtrip.mlir |  16 ++
 4 files changed, 205 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5998f736ced34..542fb4fd01077 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,15 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     operation kind can be unary (e.g. max), binary (e.g. add) or ternary
     (e.g. select).
 
+    The optional `compute_element_type` attribute specifies the element type
+    used for the computation, allowing the operation to be evaluated at a
+    different precision than the input and result element types. If specified,
+    the input operands are converted to `compute_element_type` prior to
+    computation (e.g. upcasting from `f16` to `f32`). If the result element
+    type differs from`compute_element_type`, the computed value is converted
+    back to the result type. If `compute_element_type` is not provided then
+    input and output types must match.
+
     By default, all indexing maps are identities. In the case of default
     indexing map, all input and output shapes must match. The number of dims in
     each of the identity maps is equal to the rank of the output type.
@@ -589,7 +598,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
       ElementwiseKindAttr:$kind,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      OptionalAttr<TypeAttr>:$compute_element_type
     );
 
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -619,12 +629,23 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
 
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
       /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
       static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
 
+      /// Return the `compute_element_type` inferred from result type.
+      Type getDefaultComputeElementType();
+
+      /// Return user-defined or default `compute_element_type`.
+      Type getEffectiveComputeElementType();
+
+      /// Returns `true` if inputs and result types do not match
+      /// and therefore the user must provide the compute type.
+      bool needsUserDefinedComputeElementType();
+
       /// Both user-specified and default indexing map will always depend on
       /// the current Op instance.
       static bool hasDynamicIndexingMaps() { return true; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f9c8589683ba7..eeedb6f13f26f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,6 +44,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/LogicalResult.h"
@@ -4820,6 +4821,51 @@ ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
   return SmallVector<AffineMap>(numMaps, map);
 }
 
+bool ElementwiseOp::needsUserDefinedComputeElementType() {
+  auto getElemType = [](Type t) { return getElementTypeOrSelf(t); };
+
+  auto inputs = getInputs();
+  Type resultElemType = getElemType(getOutputs().front().getType());
+
+  switch (inputs.size()) {
+  case 1: {
+    // One input, must match result element type.
+    Type inputElemType = getElemType(inputs.front().getType());
+    return inputElemType != resultElemType;
+  }
+
+  case 2: {
+    // Both inputs must match result element type.
+    Type lhsElemType = getElemType(inputs[0].getType());
+    Type rhsElemType = getElemType(inputs[1].getType());
+    return lhsElemType != resultElemType || rhsElemType != resultElemType;
+  }
+
+  case 3: {
+    // Second and third inputs must match result element type.
+    // The first operand may be control / predicate / mask.
+    Type secondElemType = getElemType(inputs[1].getType());
+    Type thirdElemType = getElemType(inputs[2].getType());
+    return secondElemType != resultElemType || thirdElemType != resultElemType;
+  }
+  }
+
+  llvm_unreachable("unknown elementwise arity");
+}
+
+Type ElementwiseOp::getDefaultComputeElementType() {
+  auto out = getOutputs().front();
+  if (auto shaped = llvm::dyn_cast<ShapedType>(out.getType()))
+    return shaped.getElementType();
+  return {};
+}
+
+Type ElementwiseOp::getEffectiveComputeElementType() {
+  if (auto attr = getComputeElementTypeAttr())
+    return attr.getValue();
+  return getDefaultComputeElementType();
+}
+
 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
   // Expect e.g. `kind = #linalg.elemwise_kind<add>`
   Attribute attr;
@@ -4840,6 +4886,16 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(
       "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
 
+  // Parse the optional 'compute_element_type = <type>'.
+  if (succeeded(parser.parseOptionalKeyword("compute_element_type"))) {
+    if (parser.parseEqual())
+      return failure();
+    Type compType;
+    if (parser.parseType(compType))
+      return failure();
+    result.addAttribute("compute_element_type", TypeAttr::get(compType));
+  }
+
   // Parse optional `indexing_maps`
   SmallVector<Attribute, 3> indexingMapsAttr;
   Attribute mapAttr;
@@ -4896,8 +4952,16 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
 void ElementwiseOp::print(OpAsmPrinter &p) {
   p << " kind=";
   p.printAttribute(getKindAttr());
-  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
-                                           "indexing_maps"};
+
+  // print `compute element type` only if it cannot be inferred unambiguously.
+  if (needsUserDefinedComputeElementType()) {
+    Type computeType = getEffectiveComputeElementType();
+    p << " compute_element_type=" << computeType;
+  }
+
+  SmallVector<StringRef, 4> elidedAttrs = {
+      "operandSegmentSizes", "kind", "compute_element_type", "indexing_maps"};
+
   unsigned arity =
       getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
   unsigned numDims = getResultRank();
@@ -4914,8 +4978,8 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
                          elidedAttrs);
 }
 
-/// Implements the block region builder for the ElementwiseOp. This is called by
-/// 'fillStructuredOpRegion'.
+/// Implements the block region builder for the ElementwiseOp.
+/// This is called by 'fillStructuredOpRegion'.
 void ElementwiseOp::regionBuilder(
     ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
     function_ref<InFlightDiagnostic()> emitError) {
@@ -4947,14 +5011,70 @@ void ElementwiseOp::regionBuilder(
   SmallVector<Value> yields;
   Value result;
 
+  // Retreive or infer the `compute element type`.
+  Type resultType = block.getArguments().back().getType();
+  Type computeElementType;
+  for (NamedAttribute na : attrs) {
+    if (na.getName() == "compute_element_type") {
+      if (auto typeAttr = llvm::dyn_cast<TypeAttr>(na.getValue())) {
+        computeElementType = typeAttr.getValue();
+      } else {
+        emitError() << "expected TypeAttr for compute_element_type";
+        return;
+      }
+      break;
+    }
+  }
+  if (!computeElementType)
+    computeElementType = resultType; // inferred.
+
+  // Cast input value to dst type.
+  // Only same-kind casts are valid (float to float, int to int).
+  auto castToDstType = [&](Value v, Type dstType) -> Value {
+    Type srcType = v.getType();
+    if (srcType == dstType)
+      return v;
+
+    // Float -> Float
+    if (auto srcFloatType = dyn_cast<FloatType>(srcType)) {
+      if (auto dstFloatType = dyn_cast<FloatType>(dstType)) {
+        if (srcFloatType.getWidth() < dstFloatType.getWidth())
+          return arith::ExtFOp::create(b, b.getLoc(), dstType, v).getResult();
+        return arith::TruncFOp::create(b, b.getLoc(), dstType, v).getResult();
+      }
+    }
+
+    // Int -> Int
+    if (auto srcIntType = dyn_cast<IntegerType>(srcType)) {
+      if (auto dstIntType = dyn_cast<IntegerType>(dstType)) {
+        if (srcIntType.getWidth() < dstIntType.getWidth()) {
+          return srcIntType.isUnsigned()
+                     ? arith::ExtUIOp::create(b, b.getLoc(), dstType, v)
+                           .getResult()
+                     : arith::ExtSIOp::create(b, b.getLoc(), dstType, v)
+                           .getResult();
+        }
+        return arith::TruncIOp::create(b, b.getLoc(), dstType, v);
+      }
+    }
+
+    emitError() << "invalid cast from " << srcType << " to " << dstType
+                << " in linalg.elementwise";
+    return nullptr;
+  };
+
   if (arityGroup == ElementwiseArityGroup::Unary) {
-    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+    Value in0 = castToDstType(block.getArgument(0), computeElementType);
+    result = castToDstType(helper.buildUnaryFn(kind.unaryFn, in0), resultType);
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
-    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+    Value in0 = castToDstType(block.getArgument(0), computeElementType);
+    Value in1 = castToDstType(block.getArgument(1), computeElementType);
+    result = castToDstType(helper.buildBinaryFn(kind.binaryFn, in0, in1),
+                           resultType);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
+    // ternary op (select) should be type casted.
     result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                    block.getArgument(1), block.getArgument(2));
 
@@ -4983,6 +5103,21 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+LogicalResult ElementwiseOp::verify() {
+  auto attr = getComputeElementTypeAttr();
+  // test validity of `compute element type` if user-defined.
+  if (!attr)
+    return success();
+  Type computeType = attr.getValue();
+
+  // Must be a scalar element type.
+  if (!computeType.isIntOrFloat())
+    return emitOpError() << "compute_element_type must be an integer or "
+                            "floating-point type, got "
+                         << computeType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..83ff3c6269d58 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,28 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
       outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
   return %r : tensor<8x16x32xf32>
 }
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_upcast_exp(%[[A:.+]]: tensor<8x16x32xf16>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[IN:.+]]: f16, %[[OUT:.+]]: f32)
+// CHECK:   %[[EXT:.+]] = arith.extf %[[IN]] : f16 to f32
+// CHECK:   %[[EXP:.+]] = math.exp %[[EXT]] : f32
+// CHECK:   linalg.yield %[[EXP]] : f32
+//
+func.func @unary_upcast_exp(%A : tensor<8x16x32xf16>, %B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+      kind = #linalg.elementwise_kind<exp>
+      compute_element_type = f32
+      ins(%A : tensor<8x16x32xf16>)
+      outs(%B : tensor<8x16x32xf32>)
+      -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..07a0c35729431 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,19 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
       outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
   return %r : tensor<1x2x3x4x5xi32>
 }
+
+// -----
+
+// CHECK: @unary_input_upcast
+// CHECK: %{{.*}} = linalg.elementwise
+// CHECK-SAME:        kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:        compute_element_type=f32
+//
+func.func @unary_input_upcast(%A : tensor<8x16x32xf16>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<exp>
+      compute_element_type=f32
+      ins(%A : tensor<8x16x32xf16>)
+      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}

>From 0c8ece97cd7d40334b072e786bd855a01080e636 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 17 Apr 2026 15:40:17 -0400
Subject: [PATCH 2/2] address review comments

Signed-off-by: Javed Absar <javed.absar at gmail.com>
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  35 ++----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 108 ++----------------
 .../elementwise/generalize-named-ops.mlir     |   1 -
 .../Dialect/Linalg/elementwise/roundtrip.mlir |  16 ---
 4 files changed, 19 insertions(+), 141 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 542fb4fd01077..a7f029e3c633b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -520,7 +520,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-                              mlir::ArrayRef<mlir::NamedAttribute>, 
+                              mlir::ArrayRef<mlir::NamedAttribute>,
                               function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       linalg::YieldOp::create(b, b.getLoc(), block.getArgument(0));
@@ -551,15 +551,6 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     operation kind can be unary (e.g. max), binary (e.g. add) or ternary
     (e.g. select).
 
-    The optional `compute_element_type` attribute specifies the element type
-    used for the computation, allowing the operation to be evaluated at a
-    different precision than the input and result element types. If specified,
-    the input operands are converted to `compute_element_type` prior to
-    computation (e.g. upcasting from `f16` to `f32`). If the result element
-    type differs from`compute_element_type`, the computed value is converted
-    back to the result type. If `compute_element_type` is not provided then
-    input and output types must match.
-
     By default, all indexing maps are identities. In the case of default
     indexing map, all input and output shapes must match. The number of dims in
     each of the identity maps is equal to the rank of the output type.
@@ -574,6 +565,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     The number of dims of the iterator-types are inferred from the rank of
     the result type.
 
+    The compute element type shall be the same as result type. For instance,
+    if input type is `f16` and result type `f32`, then input is upcast to
+    `f32` before doing the elementwise operation. Downcasting is not advised
+    but supported if user specifies, e.g. if input is `f32` and result `f16`,
+    then input is truncated before doing the elementwise operation.
+
     Example:
 
     Defining a unary linalg.elementwise with default indexing-map:
@@ -598,8 +595,7 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
       ElementwiseKindAttr:$kind,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
-      OptionalAttr<TypeAttr>:$compute_element_type
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
     );
 
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -629,23 +625,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
 
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
-  let hasVerifier = 1;
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
       /// Get the arity enum corresponding to the kind of op, e.g. if arg is
       /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
       static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
 
-      /// Return the `compute_element_type` inferred from result type.
-      Type getDefaultComputeElementType();
-
-      /// Return user-defined or default `compute_element_type`.
-      Type getEffectiveComputeElementType();
-
-      /// Returns `true` if inputs and result types do not match
-      /// and therefore the user must provide the compute type.
-      bool needsUserDefinedComputeElementType();
-
       /// Both user-specified and default indexing map will always depend on
       /// the current Op instance.
       static bool hasDynamicIndexingMaps() { return true; }
@@ -1060,7 +1045,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", [
         (void)$_state.addRegion(),
         BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
       }]>
-      
+
     ];
     let hasCustomAssemblyFormat = 1;
     let hasFolder = 1;
@@ -1191,7 +1176,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
           attributes, BatchReduceMatmulOp::getRegionBuilder(),
           BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>
-      
+
     ];
     let hasCustomAssemblyFormat = 1;
     let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eeedb6f13f26f..0a4e6f6880be0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4821,51 +4821,6 @@ ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
   return SmallVector<AffineMap>(numMaps, map);
 }
 
-bool ElementwiseOp::needsUserDefinedComputeElementType() {
-  auto getElemType = [](Type t) { return getElementTypeOrSelf(t); };
-
-  auto inputs = getInputs();
-  Type resultElemType = getElemType(getOutputs().front().getType());
-
-  switch (inputs.size()) {
-  case 1: {
-    // One input, must match result element type.
-    Type inputElemType = getElemType(inputs.front().getType());
-    return inputElemType != resultElemType;
-  }
-
-  case 2: {
-    // Both inputs must match result element type.
-    Type lhsElemType = getElemType(inputs[0].getType());
-    Type rhsElemType = getElemType(inputs[1].getType());
-    return lhsElemType != resultElemType || rhsElemType != resultElemType;
-  }
-
-  case 3: {
-    // Second and third inputs must match result element type.
-    // The first operand may be control / predicate / mask.
-    Type secondElemType = getElemType(inputs[1].getType());
-    Type thirdElemType = getElemType(inputs[2].getType());
-    return secondElemType != resultElemType || thirdElemType != resultElemType;
-  }
-  }
-
-  llvm_unreachable("unknown elementwise arity");
-}
-
-Type ElementwiseOp::getDefaultComputeElementType() {
-  auto out = getOutputs().front();
-  if (auto shaped = llvm::dyn_cast<ShapedType>(out.getType()))
-    return shaped.getElementType();
-  return {};
-}
-
-Type ElementwiseOp::getEffectiveComputeElementType() {
-  if (auto attr = getComputeElementTypeAttr())
-    return attr.getValue();
-  return getDefaultComputeElementType();
-}
-
 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
   // Expect e.g. `kind = #linalg.elemwise_kind<add>`
   Attribute attr;
@@ -4886,16 +4841,6 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(
       "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
 
-  // Parse the optional 'compute_element_type = <type>'.
-  if (succeeded(parser.parseOptionalKeyword("compute_element_type"))) {
-    if (parser.parseEqual())
-      return failure();
-    Type compType;
-    if (parser.parseType(compType))
-      return failure();
-    result.addAttribute("compute_element_type", TypeAttr::get(compType));
-  }
-
   // Parse optional `indexing_maps`
   SmallVector<Attribute, 3> indexingMapsAttr;
   Attribute mapAttr;
@@ -4953,14 +4898,8 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
   p << " kind=";
   p.printAttribute(getKindAttr());
 
-  // print `compute element type` only if it cannot be inferred unambiguously.
-  if (needsUserDefinedComputeElementType()) {
-    Type computeType = getEffectiveComputeElementType();
-    p << " compute_element_type=" << computeType;
-  }
-
-  SmallVector<StringRef, 4> elidedAttrs = {
-      "operandSegmentSizes", "kind", "compute_element_type", "indexing_maps"};
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
 
   unsigned arity =
       getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
@@ -5011,23 +4950,6 @@ void ElementwiseOp::regionBuilder(
   SmallVector<Value> yields;
   Value result;
 
-  // Retreive or infer the `compute element type`.
-  Type resultType = block.getArguments().back().getType();
-  Type computeElementType;
-  for (NamedAttribute na : attrs) {
-    if (na.getName() == "compute_element_type") {
-      if (auto typeAttr = llvm::dyn_cast<TypeAttr>(na.getValue())) {
-        computeElementType = typeAttr.getValue();
-      } else {
-        emitError() << "expected TypeAttr for compute_element_type";
-        return;
-      }
-      break;
-    }
-  }
-  if (!computeElementType)
-    computeElementType = resultType; // inferred.
-
   // Cast input value to dst type.
   // Only same-kind casts are valid (float to float, int to int).
   auto castToDstType = [&](Value v, Type dstType) -> Value {
@@ -5063,18 +4985,21 @@ void ElementwiseOp::regionBuilder(
     return nullptr;
   };
 
+  // Infer the compute element type from result type.
+  Type computeElementType = block.getArguments().back().getType();
+
+  // Create the linalg.generic body.
   if (arityGroup == ElementwiseArityGroup::Unary) {
     Value in0 = castToDstType(block.getArgument(0), computeElementType);
-    result = castToDstType(helper.buildUnaryFn(kind.unaryFn, in0), resultType);
+    result = helper.buildUnaryFn(kind.unaryFn, in0);
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
     Value in0 = castToDstType(block.getArgument(0), computeElementType);
     Value in1 = castToDstType(block.getArgument(1), computeElementType);
-    result = castToDstType(helper.buildBinaryFn(kind.binaryFn, in0, in1),
-                           resultType);
+    result = helper.buildBinaryFn(kind.binaryFn, in0, in1);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
-    // ternary op (select) should be type casted.
+    // ternary op (select) should not be type casted.
     result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                    block.getArgument(1), block.getArgument(2));
 
@@ -5103,21 +5028,6 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
-LogicalResult ElementwiseOp::verify() {
-  auto attr = getComputeElementTypeAttr();
-  // test validity of `compute element type` if user-defined.
-  if (!attr)
-    return success();
-  Type computeType = attr.getValue();
-
-  // Must be a scalar element type.
-  if (!computeType.isIntOrFloat())
-    return emitOpError() << "compute_element_type must be an integer or "
-                            "floating-point type, got "
-                         << computeType;
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index 83ff3c6269d58..483c518f86590 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -182,7 +182,6 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
 func.func @unary_upcast_exp(%A : tensor<8x16x32xf16>, %B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
   %r = linalg.elementwise
       kind = #linalg.elementwise_kind<exp>
-      compute_element_type = f32
       ins(%A : tensor<8x16x32xf16>)
       outs(%B : tensor<8x16x32xf32>)
       -> tensor<8x16x32xf32>
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 07a0c35729431..20ebdd992b5a1 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,19 +88,3 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
       outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
   return %r : tensor<1x2x3x4x5xi32>
 }
-
-// -----
-
-// CHECK: @unary_input_upcast
-// CHECK: %{{.*}} = linalg.elementwise
-// CHECK-SAME:        kind=#linalg.elementwise_kind<exp>
-// CHECK-SAME:        compute_element_type=f32
-//
-func.func @unary_input_upcast(%A : tensor<8x16x32xf16>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
-  %r = linalg.elementwise
-      kind=#linalg.elementwise_kind<exp>
-      compute_element_type=f32
-      ins(%A : tensor<8x16x32xf16>)
-      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
-  return %r : tensor<8x16x32xf32>
-}



More information about the Mlir-commits mailing list