[Mlir-commits] [mlir] [mlir][linalg][nfc] Code refactoring of LinalgOps.cpp (PR #164274)
Javed Absar
llvmlistbot at llvm.org
Mon Oct 20 08:58:08 PDT 2025
https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/164274
>From 6a6c2ce34283889c03fbc801b7c79d1ea995546a Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Mon, 20 Oct 2025 10:52:37 -0400
Subject: [PATCH 1/2] [mlir][linalg][nfc] Code refactoring of LinalgOps.cpp
The RegionBuilderHelper class has grown quite large, with its declaration and
definitions all in one blob inside LinalgOps.cpp. As the TODO in the existing
code rightly suggests, it is cleaner to extract it into its own header and
source file.
---
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 293 +-----------------
.../Dialect/Linalg/IR/RegionBuilderHelper.cpp | 237 ++++++++++++++
.../Dialect/Linalg/IR/RegionBuilderHelper.h | 148 +++++++++
4 files changed, 388 insertions(+), 291 deletions(-)
create mode 100644 mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp
create mode 100644 mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index ec433284e17ad..75f65f39f2371 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
LinalgOps.cpp
LinalgDialect.cpp
ValueBoundsOpInterfaceImpl.cpp
+ RegionBuilderHelper.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b0c8cbd..f029613a280e1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -9,7 +9,6 @@
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/AsmParser/AsmParser.h"
@@ -50,6 +49,8 @@
#include <cassert>
#include <optional>
+#include "RegionBuilderHelper.h"
+
using namespace mlir;
using namespace mlir::linalg;
@@ -411,296 +412,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
// Region is elided.
}
-//===----------------------------------------------------------------------===//
-// Region builder helper.
-// TODO: Move this to a utility library.
-// The public methods on this class are referenced directly from generated code.
-// Helper build the unary, binary, and type conversion functions defined by the
-// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
-// class.
-//
-// Implementations of the math functions must be polymorphic over numeric types,
-// internally performing necessary casts. If the function application makes no
-// sense, then the only recourse is to assert and return nullptr. This can be
-// extended later if it becomes possible to fail construction of the region. The
-// invariant should be enforced at a higher level.
-//
-// TODO: These helpers are currently type polymorphic over the class of integer
-// and floating point types, but they will not internally cast within bit
-// widths of a class (mixed precision such as i8->i32) or across classes
-// (i.e. mixed float and integer). Many such combinations are ambiguous or need
-// to be handled with care and work is being considered to extend the op
-// language to make such cases explicit. In the mean-time, violating this will
-// fail verification, which is deemed acceptable.
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-class RegionBuilderHelper {
-public:
- RegionBuilderHelper(OpBuilder &builder, Block &block)
- : builder(builder), block(block) {}
-
- // Build the unary functions defined by OpDSL.
- Value buildUnaryFn(UnaryFn unaryFn, Value arg,
- function_ref<InFlightDiagnostic()> emitError = {}) {
- if (!isFloatingPoint(arg)) {
- if (emitError) {
- emitError() << "unsupported non numeric type";
- return nullptr;
- }
- llvm_unreachable("unsupported non numeric type");
- }
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- switch (unaryFn) {
- case UnaryFn::exp:
- return math::ExpOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::log:
- return math::LogOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::abs:
- return math::AbsFOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::ceil:
- return math::CeilOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::floor:
- return math::FloorOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::negf:
- return arith::NegFOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::reciprocal: {
- Attribute oneAttr = builder.getOneAttr(arg.getType());
- auto one = arith::ConstantOp::create(builder, arg.getLoc(),
- ::cast<TypedAttr>(oneAttr));
- return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
- }
- case UnaryFn::round:
- return math::RoundOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::sqrt:
- return math::SqrtOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::rsqrt:
- return math::RsqrtOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::square:
- return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
- case UnaryFn::tanh:
- return math::TanhOp::create(builder, arg.getLoc(), arg);
- case UnaryFn::erf:
- return math::ErfOp::create(builder, arg.getLoc(), arg);
- }
- if (emitError) {
- emitError() << "unsupported unary function";
- return nullptr;
- }
- llvm_unreachable("unsupported unary function");
- }
-
- // Build the binary functions defined by OpDSL.
- // If emitError is provided, an error will be emitted if the operation is not
- // supported and a nullptr will be returned, otherwise an assertion will be
- // raised.
- Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
- function_ref<InFlightDiagnostic()> emitError = {}) {
- bool allComplex = isComplex(arg0) && isComplex(arg1);
- bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
- bool allInteger = isInteger(arg0) && isInteger(arg1);
- bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
- arg1.getType().getIntOrFloatBitWidth() == 1;
- if (!allComplex && !allFloatingPoint && !allInteger) {
- if (emitError) {
- emitError()
- << "Cannot build binary Linalg operation: expects allComplex, "
- "allFloatingPoint, or allInteger, got "
- << arg0.getType() << " and " << arg1.getType();
- return nullptr;
- }
- llvm_unreachable("unsupported non numeric type");
- }
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- switch (binaryFn) {
- case BinaryFn::add:
- if (allComplex)
- return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allBool)
- return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::sub:
- if (allComplex)
- return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allBool) {
- if (emitError) {
- emitError() << "unsupported operation: sub with bools";
- return nullptr;
- }
- llvm_unreachable("unsupported operation: sub with bools");
- }
- return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::mul:
- if (allComplex)
- return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allBool)
- return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::div:
- if (allComplex)
- return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
- if (allBool) {
- if (emitError) {
- emitError() << "unsupported operation: div with bools";
- return nullptr;
- }
- llvm_unreachable("unsupported operation: div with bools");
- }
- return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::div_unsigned:
- if (!allInteger || allBool) {
- if (emitError) {
- emitError() << "unsupported operation: unsigned div not on uint";
- return nullptr;
- }
- llvm_unreachable("unsupported operation: unsigned div not on uint");
- }
- return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::max_signed:
- assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::min_signed:
- assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::max_unsigned:
- assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::min_unsigned:
- assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
- return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
- case BinaryFn::powf:
- assert(allFloatingPoint);
- return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
- }
- if (emitError) {
- emitError() << "unsupported binary function";
- return nullptr;
- }
- llvm_unreachable("unsupported binary function");
- }
-
- // Build the ternary functions defined by OpDSL.
- Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
- function_ref<InFlightDiagnostic()> emitError = {}) {
- bool headBool =
- isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
- bool tailFloatingPoint =
- isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
- bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- switch (ternaryFn) {
- case TernaryFn::select:
- if (!headBool && !(tailFloatingPoint || tailInteger))
- llvm_unreachable("unsupported non numeric type");
- return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
- }
- if (emitError) {
- emitError() << "unsupported ternary function";
- return nullptr;
- }
- llvm_unreachable("unsupported ternary function");
- }
-
- // Build the type functions defined by OpDSL.
- Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
- function_ref<InFlightDiagnostic()> emitError = {}) {
- switch (typeFn) {
- case TypeFn::cast_signed:
- return cast(toType, operand, false);
- case TypeFn::cast_unsigned:
- return cast(toType, operand, true);
- }
- if (emitError) {
- emitError() << "unsupported type conversion function";
- return nullptr;
- }
- llvm_unreachable("unsupported type conversion function");
- }
-
- void yieldOutputs(ValueRange values) {
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- Location loc = builder.getUnknownLoc();
- YieldOp::create(builder, loc, values);
- }
-
- Value constant(const std::string &value) {
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- Location loc = builder.getUnknownLoc();
- Attribute valueAttr = parseAttribute(value, builder.getContext());
- return arith::ConstantOp::create(builder, loc,
- ::cast<TypedAttr>(valueAttr));
- }
-
- Value index(int64_t dim) {
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- return IndexOp::create(builder, builder.getUnknownLoc(), dim);
- }
-
- Type getIntegerType(unsigned width) {
- return IntegerType::get(builder.getContext(), width);
- }
-
- Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
- Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
-
-private:
- // Generates operations to cast the given operand to a specified type.
- // If the cast cannot be performed, a warning will be issued and the
- // operand returned as-is (which will presumably yield a verification
- // issue downstream).
- Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPointToEnd(&block);
- auto loc = operand.getLoc();
- if (isa<UnknownLoc>(loc)) {
- if (operand.getDefiningOp())
- loc = operand.getDefiningOp()->getLoc();
- else if (operand.getParentBlock() &&
- operand.getParentBlock()->getParentOp())
- loc = operand.getParentBlock()->getParentOp()->getLoc();
- }
- return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
- }
-
- bool isComplex(Value value) {
- return llvm::isa<ComplexType>(value.getType());
- }
- bool isFloatingPoint(Value value) {
- return llvm::isa<FloatType>(value.getType());
- }
- bool isInteger(Value value) {
- return llvm::isa<IntegerType>(value.getType());
- }
-
- OpBuilder &builder;
- Block &block;
-};
-
-} // namespace
-
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp
new file mode 100644
index 0000000000000..b1b6dd8afe8c5
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp
@@ -0,0 +1,237 @@
+//===- RegionBuilderHelper.cpp - Region Builder Helper class -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of RegionBuilderHelper class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RegionBuilderHelper.h"
+
+namespace mlir {
+namespace linalg {
+
+Value RegionBuilderHelper::buildUnaryFn(
+ UnaryFn unaryFn, Value arg, function_ref<InFlightDiagnostic()> emitError) {
+ if (!isFloatingPoint(arg)) {
+ if (emitError) {
+ emitError() << "unsupported non numeric type";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported non numeric type");
+ }
+
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ switch (unaryFn) {
+ case UnaryFn::exp:
+ return math::ExpOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::log:
+ return math::LogOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::abs:
+ return math::AbsFOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::ceil:
+ return math::CeilOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::floor:
+ return math::FloorOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::negf:
+ return arith::NegFOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::reciprocal: {
+ Attribute oneAttr = builder.getOneAttr(arg.getType());
+ auto one = arith::ConstantOp::create(builder, arg.getLoc(),
+ llvm::cast<TypedAttr>(oneAttr));
+ return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
+ }
+ case UnaryFn::round:
+ return math::RoundOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::sqrt:
+ return math::SqrtOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::rsqrt:
+ return math::RsqrtOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::square:
+ return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
+ case UnaryFn::tanh:
+ return math::TanhOp::create(builder, arg.getLoc(), arg);
+ case UnaryFn::erf:
+ return math::ErfOp::create(builder, arg.getLoc(), arg);
+ }
+
+ if (emitError) {
+ emitError() << "unsupported unary function";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported unary function");
+}
+
+Value RegionBuilderHelper::buildBinaryFn(
+ BinaryFn binaryFn, Value arg0, Value arg1,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ bool allComplex = isComplex(arg0) && isComplex(arg1);
+ bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+ bool allInteger = isInteger(arg0) && isInteger(arg1);
+ bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+ arg1.getType().getIntOrFloatBitWidth() == 1;
+
+ if (!allComplex && !allFloatingPoint && !allInteger) {
+ if (emitError) {
+ emitError()
+ << "Cannot build binary Linalg operation: expects allComplex, "
+ "allFloatingPoint, or allInteger, got "
+ << arg0.getType() << " and " << arg1.getType();
+ return nullptr;
+ }
+ llvm_unreachable("unsupported non numeric type");
+ }
+
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ switch (binaryFn) {
+ case BinaryFn::add:
+ if (allComplex)
+ return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::sub:
+ if (allComplex)
+ return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: sub with bools";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: sub with bools");
+ }
+ return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::mul:
+ if (allComplex)
+ return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::div:
+ if (allComplex)
+ return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: div with bools";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: div with bools");
+ }
+ return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::div_unsigned:
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned div not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned div not on uint");
+ }
+ return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max_signed:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min_signed:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max_unsigned:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min_unsigned:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ case BinaryFn::powf:
+ assert(allFloatingPoint);
+ return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ }
+
+ if (emitError) {
+ emitError() << "unsupported binary function";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported binary function");
+}
+
+Value RegionBuilderHelper::buildTernaryFn(
+ TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+ function_ref<InFlightDiagnostic()> emitError) {
+ bool headBool =
+ isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+ bool tailFloatingPoint =
+ isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+ bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+
+ switch (ternaryFn) {
+ case TernaryFn::select:
+ if (!headBool && !(tailFloatingPoint || tailInteger))
+ llvm_unreachable("unsupported non numeric type");
+ return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
+ }
+
+ if (emitError) {
+ emitError() << "unsupported ternary function";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported ternary function");
+}
+
+Value RegionBuilderHelper::buildTypeFn(
+ TypeFn typeFn, Type toType, Value operand,
+ function_ref<InFlightDiagnostic()> emitError) {
+ switch (typeFn) {
+ case TypeFn::cast_signed:
+ return cast(toType, operand, false);
+ case TypeFn::cast_unsigned:
+ return cast(toType, operand, true);
+ }
+
+ if (emitError) {
+ emitError() << "unsupported type conversion function";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported type conversion function");
+}
+
+void RegionBuilderHelper::yieldOutputs(ValueRange values) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ Location loc = builder.getUnknownLoc();
+ YieldOp::create(builder, loc, values);
+}
+
+Value RegionBuilderHelper::constant(const std::string &value) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ Location loc = builder.getUnknownLoc();
+ Attribute valueAttr = parseAttribute(value, builder.getContext());
+ return arith::ConstantOp::create(builder, loc,
+ llvm::cast<TypedAttr>(valueAttr));
+}
+
+} // namespace linalg
+
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
new file mode 100644
index 0000000000000..de0edac8a6091
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
@@ -0,0 +1,148 @@
+//===- RegionBuilderHelper.h - Region-Builder-Helper class declaration ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper that builds the unary, binary, ternary, and type conversion functions
+// defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code
+// that uses this class.
+//
+// Implementations of the math functions must be polymorphic over numeric types,
+// internally performing necessary casts. If the function application makes no
+// sense, then the only recourse is to assert and return nullptr. This can be
+// extended later if it becomes possible to fail construction of the region. The
+// invariant should be enforced at a higher level.
+//
+// TODO: These helpers are currently type polymorphic over the class of integer
+// and floating point types, but they will not internally cast within bit
+// widths of a class (mixed precision such as i8->i32) or across classes
+// (i.e. mixed float and integer). Many such combinations are ambiguous or need
+// to be handled with care and work is being considered to extend the op
+// language to make such cases explicit. In the mean-time, violating this will
+// fail verification, which is deemed acceptable.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LINALG_REGION_BUILDER_HELPER_H
+#define MLIR_LINALG_REGION_BUILDER_HELPER_H
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <optional>
+
+namespace mlir {
+namespace linalg {
+
+class RegionBuilderHelper {
+public:
+ RegionBuilderHelper(OpBuilder &builder, Block &block)
+ : builder(builder), block(block) {}
+
+ // Build the unary functions.
+ Value buildUnaryFn(UnaryFn unaryFn, Value arg,
+ function_ref<InFlightDiagnostic()> emitError = {});
+
+ // Build the binary functions.
+ // If emitError is provided, an error will be emitted if the operation is not
+ // supported and a nullptr will be returned, otherwise an assertion is raised.
+ Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
+ function_ref<InFlightDiagnostic()> emitError = {});
+
+ // Build the ternary functions defined by OpDSL.
+ Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+ function_ref<InFlightDiagnostic()> emitError = {});
+
+ // Build the type functions defined by OpDSL.
+ Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
+ function_ref<InFlightDiagnostic()> emitError = {});
+
+ // Create a `yieldOp` to yield `values` passed in arg.
+ void yieldOutputs(ValueRange values);
+
+ // Create a constant op with value parsed from string `value`.
+ Value constant(const std::string &value);
+
+ // Create an `index` op to extract iteration index `dim`.
+ Value index(int64_t dim) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ return IndexOp::create(builder, builder.getUnknownLoc(), dim);
+ }
+
+ // Create an integer of size `width`.
+ Type getIntegerType(unsigned width) {
+ return IntegerType::get(builder.getContext(), width);
+ }
+
+ Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+ Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
+
+private:
+ // Generates operations to cast the given operand to a specified type.
+ // If the cast cannot be performed, a warning will be issued and the
+ // operand returned as-is (which will presumably yield a verification
+ // issue downstream).
+ Value cast(Type toType, Value operand, bool isUnsignedCast) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ auto loc = operand.getLoc();
+ if (isa<UnknownLoc>(loc)) {
+ if (operand.getDefiningOp())
+ loc = operand.getDefiningOp()->getLoc();
+ else if (operand.getParentBlock() &&
+ operand.getParentBlock()->getParentOp())
+ loc = operand.getParentBlock()->getParentOp()->getLoc();
+ }
+ return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
+ }
+
+ bool isComplex(Value value) {
+ return llvm::isa<ComplexType>(value.getType());
+ }
+ bool isFloatingPoint(Value value) {
+ return llvm::isa<FloatType>(value.getType());
+ }
+ bool isInteger(Value value) {
+ return llvm::isa<IntegerType>(value.getType());
+ }
+
+ OpBuilder &builder;
+ Block &block;
+};
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_LINALG_REGION_BUILDER_HELPER_H
>From 0b6215874a68cd8e6843e4e5203d3e0c889e6f59 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Mon, 20 Oct 2025 11:57:33 -0400
Subject: [PATCH 2/2] clang format.
---
mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
index de0edac8a6091..646650d1d9f83 100644
--- a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
+++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h
@@ -29,10 +29,10 @@
#ifndef MLIR_LINALG_REGION_BUILDER_HELPER_H
#define MLIR_LINALG_REGION_BUILDER_HELPER_H
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -88,7 +88,7 @@ class RegionBuilderHelper {
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
function_ref<InFlightDiagnostic()> emitError = {});
- // Create a `yieldOp` to yield `values` passed in arg.
+ // Create a `yieldOp` to yield `values` passed in arg.
void yieldOutputs(ValueRange values);
// Create a constant op with value parsed from string `value`.
More information about the Mlir-commits
mailing list