[Mlir-commits] [mlir] [mlir][linalg][nfc] Code refactoring of LinalgOps.cpp (PR #164274)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 20 08:54:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

<details>
<summary>Changes</summary>

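The RegionBuilderHelper class has grown quite big, with its declaration and definitions all in one blob in LinalgOps.cpp.
As the TODO in the existing code rightly suggests, it is cleaner to extract it into its own header and source files (RegionBuilderHelper.h / RegionBuilderHelper.cpp).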

---

Patch is 28.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164274.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-291) 
- (added) mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp (+237) 
- (added) mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h (+148) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index ec433284e17ad..75f65f39f2371 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   LinalgOps.cpp
   LinalgDialect.cpp
   ValueBoundsOpInterfaceImpl.cpp
+  RegionBuilderHelper.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b0c8cbd..f029613a280e1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -9,7 +9,6 @@
 // This file implements the Linalg operations.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
 #include "mlir/AsmParser/AsmParser.h"
@@ -50,6 +49,8 @@
 #include <cassert>
 #include <optional>
 
+#include "RegionBuilderHelper.h"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -411,296 +412,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
   // Region is elided.
 }
 
-//===----------------------------------------------------------------------===//
-// Region builder helper.
-// TODO: Move this to a utility library.
-// The public methods on this class are referenced directly from generated code.
-// Helper build the unary, binary, and type conversion functions defined by the
-// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
-// class.
-//
-// Implementations of the math functions must be polymorphic over numeric types,
-// internally performing necessary casts. If the function application makes no
-// sense, then the only recourse is to assert and return nullptr. This can be
-// extended later if it becomes possible to fail construction of the region. The
-// invariant should be enforced at a higher level.
-//
-// TODO: These helpers are currently type polymorphic over the class of integer
-// and floating point types, but they will not internally cast within bit
-// widths of a class (mixed precision such as i8->i32) or across classes
-// (i.e. mixed float and integer). Many such combinations are ambiguous or need
-// to be handled with care and work is being considered to extend the op
-// language to make such cases explicit. In the mean-time, violating this will
-// fail verification, which is deemed acceptable.
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-class RegionBuilderHelper {
-public:
-  RegionBuilderHelper(OpBuilder &builder, Block &block)
-      : builder(builder), block(block) {}
-
-  // Build the unary functions defined by OpDSL.
-  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
-                     function_ref<InFlightDiagnostic()> emitError = {}) {
-    if (!isFloatingPoint(arg)) {
-      if (emitError) {
-        emitError() << "unsupported non numeric type";
-        return nullptr;
-      }
-      llvm_unreachable("unsupported non numeric type");
-    }
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    switch (unaryFn) {
-    case UnaryFn::exp:
-      return math::ExpOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::log:
-      return math::LogOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::abs:
-      return math::AbsFOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::ceil:
-      return math::CeilOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::floor:
-      return math::FloorOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::negf:
-      return arith::NegFOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::reciprocal: {
-      Attribute oneAttr = builder.getOneAttr(arg.getType());
-      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
-                                           ::cast<TypedAttr>(oneAttr));
-      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
-    }
-    case UnaryFn::round:
-      return math::RoundOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::sqrt:
-      return math::SqrtOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::rsqrt:
-      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::square:
-      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
-    case UnaryFn::tanh:
-      return math::TanhOp::create(builder, arg.getLoc(), arg);
-    case UnaryFn::erf:
-      return math::ErfOp::create(builder, arg.getLoc(), arg);
-    }
-    if (emitError) {
-      emitError() << "unsupported unary function";
-      return nullptr;
-    }
-    llvm_unreachable("unsupported unary function");
-  }
-
-  // Build the binary functions defined by OpDSL.
-  // If emitError is provided, an error will be emitted if the operation is not
-  // supported and a nullptr will be returned, otherwise an assertion will be
-  // raised.
-  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
-                      function_ref<InFlightDiagnostic()> emitError = {}) {
-    bool allComplex = isComplex(arg0) && isComplex(arg1);
-    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
-    bool allInteger = isInteger(arg0) && isInteger(arg1);
-    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
-                   arg1.getType().getIntOrFloatBitWidth() == 1;
-    if (!allComplex && !allFloatingPoint && !allInteger) {
-      if (emitError) {
-        emitError()
-            << "Cannot build binary Linalg operation: expects allComplex, "
-               "allFloatingPoint, or allInteger, got "
-            << arg0.getType() << " and " << arg1.getType();
-        return nullptr;
-      }
-      llvm_unreachable("unsupported non numeric type");
-    }
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    switch (binaryFn) {
-    case BinaryFn::add:
-      if (allComplex)
-        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::sub:
-      if (allComplex)
-        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allBool) {
-        if (emitError) {
-          emitError() << "unsupported operation: sub with bools";
-          return nullptr;
-        }
-        llvm_unreachable("unsupported operation: sub with bools");
-      }
-      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::mul:
-      if (allComplex)
-        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allBool)
-        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::div:
-      if (allComplex)
-        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allFloatingPoint)
-        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      if (allBool) {
-        if (emitError) {
-          emitError() << "unsupported operation: div with bools";
-          return nullptr;
-        }
-        llvm_unreachable("unsupported operation: div with bools");
-      }
-      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::div_unsigned:
-      if (!allInteger || allBool) {
-        if (emitError) {
-          emitError() << "unsupported operation: unsigned div not on uint";
-          return nullptr;
-        }
-        llvm_unreachable("unsupported operation: unsigned div not on uint");
-      }
-      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::max_signed:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::min_signed:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::max_unsigned:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::min_unsigned:
-      assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
-      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
-    case BinaryFn::powf:
-      assert(allFloatingPoint);
-      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
-    }
-    if (emitError) {
-      emitError() << "unsupported binary function";
-      return nullptr;
-    }
-    llvm_unreachable("unsupported binary function");
-  }
-
-  // Build the ternary functions defined by OpDSL.
-  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
-                       function_ref<InFlightDiagnostic()> emitError = {}) {
-    bool headBool =
-        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
-    bool tailFloatingPoint =
-        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
-    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    switch (ternaryFn) {
-    case TernaryFn::select:
-      if (!headBool && !(tailFloatingPoint || tailInteger))
-        llvm_unreachable("unsupported non numeric type");
-      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
-    }
-    if (emitError) {
-      emitError() << "unsupported ternary function";
-      return nullptr;
-    }
-    llvm_unreachable("unsupported ternary function");
-  }
-
-  // Build the type functions defined by OpDSL.
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
-                    function_ref<InFlightDiagnostic()> emitError = {}) {
-    switch (typeFn) {
-    case TypeFn::cast_signed:
-      return cast(toType, operand, false);
-    case TypeFn::cast_unsigned:
-      return cast(toType, operand, true);
-    }
-    if (emitError) {
-      emitError() << "unsupported type conversion function";
-      return nullptr;
-    }
-    llvm_unreachable("unsupported type conversion function");
-  }
-
-  void yieldOutputs(ValueRange values) {
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    Location loc = builder.getUnknownLoc();
-    YieldOp::create(builder, loc, values);
-  }
-
-  Value constant(const std::string &value) {
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    Location loc = builder.getUnknownLoc();
-    Attribute valueAttr = parseAttribute(value, builder.getContext());
-    return arith::ConstantOp::create(builder, loc,
-                                     ::cast<TypedAttr>(valueAttr));
-  }
-
-  Value index(int64_t dim) {
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
-  }
-
-  Type getIntegerType(unsigned width) {
-    return IntegerType::get(builder.getContext(), width);
-  }
-
-  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
-  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
-
-private:
-  // Generates operations to cast the given operand to a specified type.
-  // If the cast cannot be performed, a warning will be issued and the
-  // operand returned as-is (which will presumably yield a verification
-  // issue downstream).
-  Value cast(Type toType, Value operand, bool isUnsignedCast) {
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPointToEnd(&block);
-    auto loc = operand.getLoc();
-    if (isa<UnknownLoc>(loc)) {
-      if (operand.getDefiningOp())
-        loc = operand.getDefiningOp()->getLoc();
-      else if (operand.getParentBlock() &&
-               operand.getParentBlock()->getParentOp())
-        loc = operand.getParentBlock()->getParentOp()->getLoc();
-    }
-    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
-  }
-
-  bool isComplex(Value value) {
-    return llvm::isa<ComplexType>(value.getType());
-  }
-  bool isFloatingPoint(Value value) {
-    return llvm::isa<FloatType>(value.getType());
-  }
-  bool isInteger(Value value) {
-    return llvm::isa<IntegerType>(value.getType());
-  }
-
-  OpBuilder &builder;
-  Block █
-};
-
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp
new file mode 100644
index 0000000000000..b1b6dd8afe8c5
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp
@@ -0,0 +1,237 @@
+//===- RegionBuilderHelper.cpp - Region Builder Helper class    -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of RegionBuilderHelper class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RegionBuilderHelper.h"
+
+namespace mlir {
+namespace linalg {
+
+Value RegionBuilderHelper::buildUnaryFn(
+    UnaryFn unaryFn, Value arg, function_ref<InFlightDiagnostic()> emitError) {
+  if (!isFloatingPoint(arg)) {
+    if (emitError) {
+      emitError() << "unsupported non numeric type";
+      return nullptr;
+    }
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPointToEnd(&block);
+  switch (unaryFn) {
+  case UnaryFn::exp:
+    return math::ExpOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::log:
+    return math::LogOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::abs:
+    return math::AbsFOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::ceil:
+    return math::CeilOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::floor:
+    return math::FloorOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::negf:
+    return arith::NegFOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::reciprocal: {
+    Attribute oneAttr = builder.getOneAttr(arg.getType());
+    auto one = arith::ConstantOp::create(builder, arg.getLoc(),
+                                         llvm::cast<TypedAttr>(oneAttr));
+    return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
+  }  
+  case UnaryFn::round:
+    return math::RoundOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::sqrt:
+    return math::SqrtOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::rsqrt:
+    return math::RsqrtOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::square:
+    return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
+  case UnaryFn::tanh:
+    return math::TanhOp::create(builder, arg.getLoc(), arg);
+  case UnaryFn::erf:
+    return math::ErfOp::create(builder, arg.getLoc(), arg);
+  }
+  
+  if (emitError) {
+    emitError() << "unsupported unary function";
+    return nullptr;
+  }
+  llvm_unreachable("unsupported unary function");
+}
+
+Value RegionBuilderHelper::buildBinaryFn(
+    BinaryFn binaryFn, Value arg0, Value arg1,
+    function_ref<InFlightDiagnostic()> emitError) {
+
+  bool allComplex = isComplex(arg0) && isComplex(arg1);
+  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+  bool allInteger = isInteger(arg0) && isInteger(arg1);
+  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+                 arg1.getType().getIntOrFloatBitWidth() == 1;
+
+  if (!allComplex && !allFloatingPoint && !allInteger) {
+    if (emitError) {
+      emitError()
+          << "Cannot build binary Linalg operation: expects allComplex, "
+             "allFloatingPoint, or allInteger, got "
+          << arg0.getType() << " and " << arg1.getType();
+      return nullptr;
+    }
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPointToEnd(&block);
+  switch (binaryFn) {
+  case BinaryFn::add:
+    if (allComplex)
+      return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::sub:
+    if (allComplex)
+      return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allBool) {
+      if (emitError) {
+        emitError() << "unsupported operation: sub with bools";
+        return nullptr;
+      }
+      llvm_unreachable("unsupported operation: sub with bools");
+    }
+    return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::mul:
+    if (allComplex)
+      return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::div:
+    if (allComplex)
+      return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    if (allBool) {
+      if (emitError) {
+        emitError() << "unsupported operation: div with bools";
+        return nullptr;
+      }
+      llvm_unreachable("unsupported operation: div with bools");
+    }
+    return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::div_unsigned:
+    if (!allInteger || allBool) {
+      if (emitError) {
+        emitError() << "unsupported operation: unsigned div not on uint";
+        return nullptr;
+      }
+      llvm_unreachable("unsupported operation: unsigned div not on uint");
+    }
+    return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::max_signed:
+    assert(!allComplex);
+    if (allFloatingPoint)
+      return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::min_signed:
+    assert(!allComplex);
+    if (allFloatingPoint)
+      return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::max_unsigned:
+    assert(!allComplex);
+    if (allFloatingPoint)
+      return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::min_unsigned:
+    assert(!allComplex);
+    if (allFloatingPoint)
+      return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+    return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
+  case BinaryFn::powf:
+    assert(allFloatingPoint);
+    return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
+  }
+
+  if (emitError) {
+    emitError() << "unsup...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/164274


More information about the Mlir-commits mailing list