[Mlir-commits] [llvm] [mlir] [MLIR][IR] Add ConstantLikeInterface for extensible constant creation (PR #177740)

Ryan Kim llvmlistbot at llvm.org
Sat Jan 24 01:03:19 PST 2026


https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/177740

>From 16e367ed96648450d69cc94b5632d4f15ef7773d Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Fri, 23 Jan 2026 08:53:45 +0900
Subject: [PATCH] [MLIR][IR] Add ConstantLikeInterface for extensible constant
 creation

Introduces a new type interface `ConstantLikeInterface` that allows custom
types to define their own constant creation logic. This is particularly
useful for domain-specific types (e.g., finite field elements, custom
numeric types) that need special handling for constant attributes and ops.

Interface methods:
- `createConstantAttr(int64_t)`: Create a TypedAttr from a scalar value
- `createConstantAttrFromValues(ArrayRef<APInt>)`: Create a TypedAttr from
  multiple APInt values (useful for composite types like extension fields)
- `createConstantOp(OpBuilder&, Location, TypedAttr)`: Create a dialect-
  specific constant operation instead of relying on arith.constant
- `overrideShapedType(ShapedType)`: Override the shaped type used for
  tensor/vector constants (e.g., to use a storage type for the elements)

Integration points:
- `Builder::getZeroAttr` and `Builder::getOneAttr` now check for this
  interface first, enabling custom types to provide their own constant
  creation logic before falling back to built-in type handling.
- ElementwiseOpFusion uses this interface to create scalar constants when
  fusing splat operands into linalg.generic operations.
---
 mlir/include/mlir/IR/CMakeLists.txt           |   5 +
 mlir/include/mlir/IR/ConstantLikeInterface.h  |  34 ++
 mlir/include/mlir/IR/ConstantLikeInterface.td |  82 +++
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  29 +-
 mlir/lib/IR/Builders.cpp                      |  35 +-
 mlir/lib/IR/CMakeLists.txt                    |   2 +
 mlir/lib/IR/ConstantLikeInterface.cpp         |  15 +
 mlir/unittests/IR/CMakeLists.txt              |   5 +-
 .../IR/ConstantLikeInterfaceTest.cpp          | 516 ++++++++++++++++++
 utils/bazel/MODULE.bazel.lock                 |   5 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  19 +
 .../mlir/unittests/BUILD.bazel                |   1 +
 12 files changed, 737 insertions(+), 11 deletions(-)
 create mode 100644 mlir/include/mlir/IR/ConstantLikeInterface.h
 create mode 100644 mlir/include/mlir/IR/ConstantLikeInterface.td
 create mode 100644 mlir/lib/IR/ConstantLikeInterface.cpp
 create mode 100644 mlir/unittests/IR/ConstantLikeInterfaceTest.cpp

diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 0b3079cde568d..d7d1f116987a4 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -55,6 +55,11 @@ mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_mlir_generic_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS ConstantLikeInterface.td)
+mlir_tablegen(ConstantLikeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(ConstantLikeInterface.cpp.inc -gen-type-interface-defs)
+add_mlir_generic_tablegen_target(MLIRConstantLikeInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
 mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/ConstantLikeInterface.h b/mlir/include/mlir/IR/ConstantLikeInterface.h
new file mode 100644
index 0000000000000..e8d4a2794d2c1
--- /dev/null
+++ b/mlir/include/mlir/IR/ConstantLikeInterface.h
@@ -0,0 +1,33 @@
+//===- ConstantLikeInterface.h - Constant Creation Interface ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for type interfaces that allow types to
+// define how to create constant attributes and operations for themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_CONSTANTLIKEINTERFACE_H
+#define MLIR_IR_CONSTANTLIKEINTERFACE_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Types.h"
+
+namespace llvm {
+class APInt;
+} // namespace llvm
+
+namespace mlir {
+class TypedAttr;
+class OpBuilder;
+class Location;
+class Operation;
+} // namespace mlir
+
+#include "mlir/IR/ConstantLikeInterface.h.inc"
+
+#endif // MLIR_IR_CONSTANTLIKEINTERFACE_H
diff --git a/mlir/include/mlir/IR/ConstantLikeInterface.td b/mlir/include/mlir/IR/ConstantLikeInterface.td
new file mode 100644
index 0000000000000..af5d948bc7260
--- /dev/null
+++ b/mlir/include/mlir/IR/ConstantLikeInterface.td
@@ -0,0 +1,82 @@
+//===- ConstantLikeInterface.td - Constant creation interfaces -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for type interfaces that allow types to
+// define how to create constant attributes and operations for themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_CONSTANTLIKEINTERFACE_TD_
+#define MLIR_IR_CONSTANTLIKEINTERFACE_TD_
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ConstantLikeInterface
+//===----------------------------------------------------------------------===//
+
+def ConstantLikeInterface : TypeInterface<"ConstantLikeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows types to define how to create constant attributes
+    and operations for their values. This decouples generic MLIR code from
+    specific constant operation types, enabling better layering and extensibility.
+
+    Types implementing this interface can provide custom constant creation logic,
+    which is particularly useful for domain-specific types (e.g., field elements,
+    custom numeric types) that need special constant handling.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant attribute for this type from the given int64 value.
+        Returns null if the type does not support this operation.
+      }],
+      /*retTy=*/"::mlir::TypedAttr",
+      /*methodName=*/"createConstantAttr",
+      /*args=*/(ins "int64_t":$value)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant attribute for this type from the given APInt values.
+        Each APInt should be compatible with this type's bit width and semantics.
+        Returns null if the type does not support this operation.
+      }],
+      /*retTy=*/"::mlir::TypedAttr",
+      /*methodName=*/"createConstantAttrFromValues",
+      /*args=*/(ins "::llvm::ArrayRef<::llvm::APInt>":$values)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant operation for this type with the given attribute.
+        The builder's insertion point should be set before calling this method.
+        Returns null if the type does not support constant operation creation.
+      }],
+      /*retTy=*/"::mlir::Operation *",
+      /*methodName=*/"createConstantOp",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc,
+                    "::mlir::TypedAttr":$attr)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Allows custom types to override the shaped type (tensor/vector) used for
+        constant creation. This is useful when a custom element type needs to
+        substitute the shaped type with a different representation (e.g., using
+        a storage type instead of the semantic type).
+        Returns the potentially modified shaped type.
+      }],
+      /*retTy=*/"::mlir::ShapedType",
+      /*methodName=*/"overrideShapedType",
+      /*args=*/(ins "::mlir::ShapedType":$shapedType)
+    >
+  ];
+}
+
+#endif // MLIR_IR_CONSTANTLIKEINTERFACE_TD_
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 72acd02d0d13d..cff08755052ca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/ConstantLikeInterface.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
@@ -2307,8 +2308,27 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       }
 
       // Create a constant scalar value from the splat constant.
-      Value scalarConstant =
-          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+      Value scalarConstant;
+      Region &region = genericOp->getRegion(0);
+      Block &entryBlock = *region.begin();
+      Value argument = entryBlock.getArgument(opOperand->getOperandNumber());
+      Type argType = argument.getType();
+
+      // Try to use ConstantLikeInterface for custom constant creation
+      if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(argType)) {
+        if (Operation *constOp = constType.createConstantOp(
+                rewriter, def->getLoc(), constantAttr)) {
+          scalarConstant = constOp->getResult(0);
+        } else {
+          // Fallback to arith::ConstantOp
+          scalarConstant =
+              arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+        }
+      } else {
+        // Default to arith::ConstantOp for types without the interface
+        scalarConstant =
+            arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+      }
 
       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp =
@@ -2323,11 +2343,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
 
       // Map the block argument corresponding to the replaced argument with the
       // scalar constant.
-      Region &region = genericOp->getRegion(0);
-      Block &entryBlock = *region.begin();
       IRMapping mapping;
-      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
-                  scalarConstant);
+      mapping.map(argument, scalarConstant);
       Region &fusedRegion = fusedOp->getRegion(0);
       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                  mapping);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8f199b60fccdc..d7b2d96c71a3c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ConstantLikeInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -322,6 +323,14 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 TypedAttr Builder::getZeroAttr(Type type) {
+  // Check if the type implements ConstantLikeInterface for custom constant
+  // creation
+  if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(type)) {
+    if (auto attr = constType.createConstantAttr(0))
+      return attr;
+  }
+
+  // Fallback to built-in type handling
   if (llvm::isa<FloatType>(type))
     return getFloatAttr(type, 0.0);
   if (llvm::isa<IndexType>(type))
@@ -331,15 +340,30 @@ TypedAttr Builder::getZeroAttr(Type type) {
                           APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
+    auto elementType = vtType.getElementType();
+
+    auto element = getZeroAttr(elementType);
     if (!element)
       return {};
+    // Check if element type implements ConstantLikeInterface for custom
+    // shaped constant creation
+    if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(elementType)) {
+      vtType = constType.overrideShapedType(vtType);
+    }
     return DenseElementsAttr::get(vtType, element);
   }
   return {};
 }
 
 TypedAttr Builder::getOneAttr(Type type) {
+  // Check if the type implements ConstantLikeInterface for custom constant
+  // creation
+  if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(type)) {
+    if (auto attr = constType.createConstantAttr(1))
+      return attr;
+  }
+
+  // Fallback to built-in type handling
   if (llvm::isa<FloatType>(type))
     return getFloatAttr(type, 1.0);
   if (llvm::isa<IndexType>(type))
@@ -349,9 +373,16 @@ TypedAttr Builder::getOneAttr(Type type) {
                           APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto elementType = vtType.getElementType();
+
+    auto element = getOneAttr(elementType);
     if (!element)
       return {};
+    // Check if element type implements ConstantLikeInterface for custom
+    // shaped constant creation
+    if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(elementType)) {
+      vtType = constType.overrideShapedType(vtType);
+    }
     return DenseElementsAttr::get(vtType, element);
   }
   return {};
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index d95bdc957e3c2..cf12cb5e53299 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -18,6 +18,7 @@ add_mlir_library(MLIRIR
   BuiltinDialectBytecode.cpp
   BuiltinTypes.cpp
   BuiltinTypeInterfaces.cpp
+  ConstantLikeInterface.cpp
   Diagnostics.cpp
   Dialect.cpp
   DialectResourceBlobManager.cpp
@@ -61,6 +62,7 @@ add_mlir_library(MLIRIR
   MLIRBuiltinTypeInterfacesIncGen
   MLIRCallInterfacesIncGen
   MLIRCastInterfacesIncGen
+  MLIRConstantLikeInterfaceIncGen
   MLIRDataLayoutInterfacesIncGen
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
diff --git a/mlir/lib/IR/ConstantLikeInterface.cpp b/mlir/lib/IR/ConstantLikeInterface.cpp
new file mode 100644
index 0000000000000..028434eea5771
--- /dev/null
+++ b/mlir/lib/IR/ConstantLikeInterface.cpp
@@ -0,0 +1,15 @@
+//===- ConstantLikeInterface.cpp - Constant Creation Interface --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ConstantLikeInterface.h"
+
+namespace mlir {
+
+#include "mlir/IR/ConstantLikeInterface.cpp.inc"
+
+} // namespace mlir
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index dd3b110dcd295..af04c8b5fd9ed 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_unittest(MLIRIRTests
   AffineMapTest.cpp
   AttributeTest.cpp
   AttrTypeReplacerTest.cpp
+  ConstantLikeInterfaceTest.cpp
   Diagnostic.cpp
   DialectTest.cpp
   DistinctAttributeAllocatorTest.cpp
@@ -14,7 +15,7 @@ add_mlir_unittest(MLIRIRTests
   MemrefLayoutTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
-  RemarkTest.cpp  
+  RemarkTest.cpp
   ShapedTypeTest.cpp
   SymbolTableTest.cpp
   TypeTest.cpp
@@ -24,9 +25,11 @@ add_mlir_unittest(MLIRIRTests
   BlobManagerTest.cpp
 
   DEPENDS
+  MLIRConstantLikeInterfaceIncGen
   MLIRTestInterfaceIncGen
 )
 target_include_directories(MLIRIRTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
 mlir_target_link_libraries(MLIRIRTests PRIVATE MLIRIR)
+target_link_libraries(MLIRIRTests PRIVATE MLIRArithDialect)
 target_link_libraries(MLIRIRTests PRIVATE MLIRTestDialect)
 target_link_libraries(MLIRIRTests PRIVATE MLIRRemarkStreamer)
diff --git a/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp b/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp
new file mode 100644
index 0000000000000..d956e1fdb5e2d
--- /dev/null
+++ b/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp
@@ -0,0 +1,516 @@
+//===- ConstantLikeInterfaceTest.cpp - Test ConstantLikeInterface ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This implements tests for the ConstantLikeInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ConstantLikeInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test External Models
+//===----------------------------------------------------------------------===//
+
+/// External interface model for IntegerType that provides custom constant
+/// creation. This model multiplies values by 2 to demonstrate the interface
+/// is being used.
+struct TestConstantLikeIntegerModel
+    : public ConstantLikeInterface::ExternalModel<TestConstantLikeIntegerModel,
+                                                  IntegerType> {
+  /// Creates an IntegerAttr with the value multiplied by 2 (for testing).
+  TypedAttr createConstantAttr(Type type, int64_t value) const {
+    auto intType = llvm::cast<IntegerType>(type);
+    // Multiply by 2 to verify this method is called instead of the default.
+    return IntegerAttr::get(intType, value * 2);
+  }
+
+  /// Creates an IntegerAttr from APInt values.
+  /// For IntegerType, we expect exactly one APInt value and add 10 to it.
+  TypedAttr createConstantAttrFromValues(Type type,
+                                         ArrayRef<APInt> values) const {
+    if (values.size() != 1)
+      return {};
+    auto intType = llvm::cast<IntegerType>(type);
+    // Add 10 to verify this method is called.
+    APInt result = values[0] + 10;
+    return IntegerAttr::get(intType, result);
+  }
+
+  /// Creates an arith.constant operation for the integer type.
+  Operation *createConstantOp(Type type, OpBuilder &builder, Location loc,
+                              TypedAttr attr) const {
+    return arith::ConstantOp::create(builder, loc, attr);
+  }
+
+  /// Overrides the shaped type by changing the element type to i64.
+  /// This simulates types that need a different storage representation.
+  ShapedType overrideShapedType(Type type, ShapedType shapedType) const {
+    // Override to use i64 as storage type for demonstration.
+    IntegerType storageType = IntegerType::get(type.getContext(), 64);
+    return shapedType.clone(storageType);
+  }
+};
+
+/// External interface model for Float32Type with distinct behavior.
+struct TestConstantLikeFloatModel
+    : public ConstantLikeInterface::ExternalModel<TestConstantLikeFloatModel,
+                                                  Float32Type> {
+  /// Creates a FloatAttr with the value plus 100 (for testing).
+  TypedAttr createConstantAttr(Type type, int64_t value) const {
+    // Add 100 to verify this method is called.
+    return FloatAttr::get(type, static_cast<double>(value) + 100.0);
+  }
+
+  TypedAttr createConstantAttrFromValues(Type type,
+                                         ArrayRef<APInt> values) const {
+    if (values.size() != 1)
+      return {};
+    // Convert APInt to double and add 50.
+    double d = static_cast<double>(values[0].getSExtValue()) + 50.0;
+    return FloatAttr::get(type, d);
+  }
+
+  Operation *createConstantOp(Type type, OpBuilder &builder, Location loc,
+                              TypedAttr attr) const {
+    return arith::ConstantOp::create(builder, loc, attr);
+  }
+
+  ShapedType overrideShapedType(Type type, ShapedType shapedType) const {
+    // Keep unchanged for float.
+    return shapedType;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Basic Interface Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, InterfaceNotAttached) {
+  // Without attaching the interface, types should not implement it.
+  MLIRContext context;
+  IntegerType i32 = IntegerType::get(&context, 32);
+  EXPECT_FALSE(isa<ConstantLikeInterface>(i32));
+
+  Float32Type f32 = Float32Type::get(&context);
+  EXPECT_FALSE(isa<ConstantLikeInterface>(f32));
+}
+
+TEST(ConstantLikeInterfaceTest, AttachInterfaceToIntegerType) {
+  MLIRContext context;
+
+  // Attach the interface to IntegerType.
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  IntegerType i32 = IntegerType::get(&context, 32);
+  auto iface = dyn_cast<ConstantLikeInterface>(i32);
+  ASSERT_TRUE(iface != nullptr);
+
+  // Test createConstantAttr - our model multiplies by 2.
+  TypedAttr attr = iface.createConstantAttr(5);
+  ASSERT_TRUE(attr != nullptr);
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  ASSERT_TRUE(intAttr != nullptr);
+  EXPECT_EQ(intAttr.getInt(), 10); // 5 * 2 = 10
+}
+
+TEST(ConstantLikeInterfaceTest, AttachInterfaceToFloatType) {
+  MLIRContext context;
+
+  // Attach the interface to Float32Type.
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  Float32Type f32 = Float32Type::get(&context);
+  auto iface = dyn_cast<ConstantLikeInterface>(f32);
+  ASSERT_TRUE(iface != nullptr);
+
+  // Test createConstantAttr - our model adds 100.
+  TypedAttr attr = iface.createConstantAttr(1);
+  ASSERT_TRUE(attr != nullptr);
+  auto floatAttr = dyn_cast<FloatAttr>(attr);
+  ASSERT_TRUE(floatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(floatAttr.getValueAsDouble(), 101.0); // 1 + 100 = 101
+}
+
+TEST(ConstantLikeInterfaceTest, InterfaceNotSharedAcrossContexts) {
+  MLIRContext context1;
+  MLIRContext context2;
+
+  // Attach interface only to context1.
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context1);
+
+  IntegerType i32Ctx1 = IntegerType::get(&context1, 32);
+  IntegerType i32Ctx2 = IntegerType::get(&context2, 32);
+
+  // context1 should have the interface.
+  EXPECT_TRUE(isa<ConstantLikeInterface>(i32Ctx1));
+
+  // context2 should NOT have the interface.
+  EXPECT_FALSE(isa<ConstantLikeInterface>(i32Ctx2));
+}
+
+//===----------------------------------------------------------------------===//
+// createConstantAttrFromValues Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValues) {
+  MLIRContext context;
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  IntegerType i32 = IntegerType::get(&context, 32);
+  auto iface = cast<ConstantLikeInterface>(i32);
+
+  // Test createConstantAttrFromValues with a single APInt.
+  APInt value(32, 7); // 7
+  SmallVector<APInt> values = {value};
+  TypedAttr attr = iface.createConstantAttrFromValues(values);
+  ASSERT_TRUE(attr != nullptr);
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  ASSERT_TRUE(intAttr != nullptr);
+  EXPECT_EQ(intAttr.getInt(), 17); // 7 + 10 = 17
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValuesFloat) {
+  MLIRContext context;
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  Float32Type f32 = Float32Type::get(&context);
+  auto iface = cast<ConstantLikeInterface>(f32);
+
+  // Test createConstantAttrFromValues for float.
+  APInt value(32, 25);
+  SmallVector<APInt> values = {value};
+  TypedAttr attr = iface.createConstantAttrFromValues(values);
+  ASSERT_TRUE(attr != nullptr);
+  auto floatAttr = dyn_cast<FloatAttr>(attr);
+  ASSERT_TRUE(floatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(floatAttr.getValueAsDouble(), 75.0); // 25 + 50 = 75
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValuesInvalidSize) {
+  MLIRContext context;
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  IntegerType i32 = IntegerType::get(&context, 32);
+  auto iface = cast<ConstantLikeInterface>(i32);
+
+  // Test with wrong number of values (our model expects exactly 1).
+  APInt v1(32, 1), v2(32, 2);
+  SmallVector<APInt> values = {v1, v2};
+  TypedAttr attr = iface.createConstantAttrFromValues(values);
+  EXPECT_TRUE(attr == nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// createConstantOp Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, CreateConstantOp) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  // Create a module to hold operations.
+  OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+  OpBuilder builder(module->getBody(), module->getBody()->begin());
+  Location loc = builder.getUnknownLoc();
+
+  IntegerType i32 = IntegerType::get(&context, 32);
+  auto iface = cast<ConstantLikeInterface>(i32);
+
+  // Create an attribute and then a constant operation.
+  TypedAttr attr = iface.createConstantAttr(21);
+  ASSERT_TRUE(attr != nullptr);
+
+  Operation *constOp = iface.createConstantOp(builder, loc, attr);
+  ASSERT_TRUE(constOp != nullptr);
+
+  // Must be an arith.constant; ASSERT so the cast<> below cannot assert-fail.
+  ASSERT_TRUE(isa<arith::ConstantOp>(constOp));
+  auto arithConst = cast<arith::ConstantOp>(constOp);
+  auto resultAttr = dyn_cast<IntegerAttr>(arithConst.getValue());
+  ASSERT_TRUE(resultAttr != nullptr);
+  EXPECT_EQ(resultAttr.getInt(), 42); // 21 * 2 = 42
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantOpFloat) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  // Create a module to hold operations.
+  OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+  OpBuilder builder(module->getBody(), module->getBody()->begin());
+  Location loc = builder.getUnknownLoc();
+  Float32Type f32 = Float32Type::get(&context);
+  auto iface = cast<ConstantLikeInterface>(f32);
+
+  TypedAttr attr = iface.createConstantAttr(50);
+  ASSERT_TRUE(attr != nullptr);
+  Operation *constOp = iface.createConstantOp(builder, loc, attr);
+  ASSERT_TRUE(constOp != nullptr);
+  ASSERT_TRUE(isa<arith::ConstantOp>(constOp)); // Guards the cast<> below.
+  auto arithConst = cast<arith::ConstantOp>(constOp);
+  auto resultAttr = dyn_cast<FloatAttr>(arithConst.getValue());
+  ASSERT_TRUE(resultAttr != nullptr);
+  EXPECT_DOUBLE_EQ(resultAttr.getValueAsDouble(), 150.0); // 50 + 100 = 150
+}
+
+//===----------------------------------------------------------------------===//
+// overrideShapedType Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedType) {
+  MLIRContext context;
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  IntegerType i32 = IntegerType::get(&context, 32);
+  auto iface = cast<ConstantLikeInterface>(i32);
+
+  // Create a tensor type with i32 elements.
+  RankedTensorType tensorType = RankedTensorType::get({2, 3}, i32);
+
+  // Override should change element type to i64.
+  ShapedType overridden = iface.overrideShapedType(tensorType);
+  ASSERT_TRUE(overridden != nullptr);
+
+  // Shape should be preserved.
+  EXPECT_EQ(overridden.getShape(), tensorType.getShape());
+
+  // Element type should now be i64.
+  auto elementType = dyn_cast<IntegerType>(overridden.getElementType());
+  ASSERT_TRUE(elementType != nullptr);
+  EXPECT_EQ(elementType.getWidth(), 64u);
+}
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedTypePreservesShape) {
+  MLIRContext context;
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  IntegerType i16 = IntegerType::get(&context, 16);
+  auto iface = cast<ConstantLikeInterface>(i16);
+
+  // Test with various shapes.
+  RankedTensorType tensor1D = RankedTensorType::get({10}, i16);
+  RankedTensorType tensor3D = RankedTensorType::get({2, 3, 4}, i16);
+
+  ShapedType overridden1D = iface.overrideShapedType(tensor1D);
+  ShapedType overridden3D = iface.overrideShapedType(tensor3D);
+
+  EXPECT_EQ(overridden1D.getShape(), tensor1D.getShape());
+  EXPECT_EQ(overridden3D.getShape(), tensor3D.getShape());
+
+  // Both should have i64 element type.
+  EXPECT_EQ(cast<IntegerType>(overridden1D.getElementType()).getWidth(), 64u);
+  EXPECT_EQ(cast<IntegerType>(overridden3D.getElementType()).getWidth(), 64u);
+}
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedTypeFloat) {
+  MLIRContext context;
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  Float32Type f32 = Float32Type::get(&context);
+  auto iface = cast<ConstantLikeInterface>(f32);
+
+  // Create a tensor type with f32 elements.
+  RankedTensorType tensorType = RankedTensorType::get({4, 5}, f32);
+
+  // Float model keeps type unchanged.
+  ShapedType overridden = iface.overrideShapedType(tensorType);
+  EXPECT_EQ(overridden, tensorType);
+}
+
+//===----------------------------------------------------------------------===//
+// Builder Integration Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, BuilderGetZeroAttrUsesInterface) {
+  // Checks that Builder::getZeroAttr dispatches to ConstantLikeInterface.
+  MLIRContext context;
+  Builder builder(&context);
+  Float32Type f32 = Float32Type::get(&context);
+
+  // Without the interface attached, getZeroAttr uses the built-in handling.
+  TypedAttr defaultZero = builder.getZeroAttr(f32);
+  ASSERT_TRUE(defaultZero != nullptr);
+  auto defaultFloatAttr = dyn_cast<FloatAttr>(defaultZero);
+  ASSERT_TRUE(defaultFloatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(defaultFloatAttr.getValueAsDouble(), 0.0);
+
+  // Use the float model rather than the integer one: the integer model maps 0
+  // to 0 * 2 = 0, which cannot be told apart from the built-in result, whereas
+  // the float model maps 0 to 0 + 100 = 100.
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  // After attaching, getZeroAttr must route through the interface.
+  TypedAttr customZero = builder.getZeroAttr(f32);
+  ASSERT_TRUE(customZero != nullptr);
+  auto customFloatAttr = dyn_cast<FloatAttr>(customZero);
+  ASSERT_TRUE(customFloatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(customFloatAttr.getValueAsDouble(), 100.0);
+}
+
+TEST(ConstantLikeInterfaceTest, BuilderGetOneAttrUsesInterface) {
+  MLIRContext context;
+
+  // Attach the interface to IntegerType.
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+  Builder builder(&context);
+  IntegerType i32 = IntegerType::get(&context, 32);
+
+  // getOneAttr should use our interface which multiplies by 2.
+  TypedAttr customOne = builder.getOneAttr(i32);
+  ASSERT_TRUE(customOne != nullptr);
+  auto customIntAttr = dyn_cast<IntegerAttr>(customOne);
+  ASSERT_TRUE(customIntAttr != nullptr);
+  EXPECT_EQ(customIntAttr.getInt(), 2); // 1 * 2 = 2
+}
+
+// Checks getZeroAttr on a ranked tensor whose element type has a model.
+TEST(ConstantLikeInterfaceTest, ShapedTypeWithInterface) {
+  MLIRContext ctx;
+
+  // The float model is used because it leaves the shaped type untouched,
+  // so DenseElementsAttr creation sees no element type mismatch.
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(ctx);
+
+  auto elemTy = Float32Type::get(&ctx);
+  auto tensorTy = RankedTensorType::get({2, 3}, elemTy);
+  Builder b(&ctx);
+
+  // A tensor zero is expected to come back as a DenseElementsAttr.
+  TypedAttr zero = b.getZeroAttr(tensorTy);
+  ASSERT_TRUE(zero != nullptr);
+  auto dense = dyn_cast<DenseElementsAttr>(zero);
+  ASSERT_TRUE(dense != nullptr);
+
+  // The float model adds 100 to its input, so each element should be 100.0.
+  for (const APFloat &element : dense.getValues<APFloat>())
+    EXPECT_DOUBLE_EQ(element.convertToDouble(), 100.0);
+}
+
+//===----------------------------------------------------------------------===//
+// Delayed Registration Tests
+//===----------------------------------------------------------------------===//
+
+// Checks that a model added via a DialectRegistry extension gets attached.
+TEST(ConstantLikeInterfaceTest, DelayedRegistration) {
+  DialectRegistry registry;
+  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect * /*dialect*/) {
+    IntegerType::attachInterface<TestConstantLikeIntegerModel>(*ctx);
+  });
+  MLIRContext context(registry);
+  IntegerType i32 = IntegerType::get(&context, 32);
+
+  // Fatal check: cast<> below is invalid when the interface is missing.
+  ASSERT_TRUE(isa<ConstantLikeInterface>(i32));
+  auto iface = cast<ConstantLikeInterface>(i32);
+
+  // The integer model doubles its input, so 7 is expected to become 14.
+  TypedAttr attr = iface.createConstantAttr(7);
+  ASSERT_TRUE(attr != nullptr);
+  EXPECT_EQ(cast<IntegerAttr>(attr).getInt(), 14);
+}
+
+//===----------------------------------------------------------------------===//
+// Combined Workflow Tests
+//===----------------------------------------------------------------------===//
+
+// Exercises every ConstantLikeInterface method on a single i32 type.
+TEST(ConstantLikeInterfaceTest, FullWorkflowWithAllMethods) {
+  MLIRContext ctx;
+  ctx.loadDialect<arith::ArithDialect>();
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(ctx);
+
+  OwningOpRef<ModuleOp> moduleRef = ModuleOp::create(UnknownLoc::get(&ctx));
+  Block *body = moduleRef->getBody();
+  OpBuilder b(body, body->begin());
+  Location unknownLoc = b.getUnknownLoc();
+
+  auto intTy = IntegerType::get(&ctx, 32);
+  auto typeIface = cast<ConstantLikeInterface>(intTy);
+
+  // Scalar attribute from an int64_t: expected result is 5 * 2 = 10.
+  TypedAttr scalarAttr = typeIface.createConstantAttr(5);
+  EXPECT_EQ(cast<IntegerAttr>(scalarAttr).getInt(), 10);
+
+  // Attribute from APInt values: expected result is 20 + 10 = 30.
+  const APInt twenty(32, 20);
+  TypedAttr apIntAttr = typeIface.createConstantAttrFromValues({twenty});
+  EXPECT_EQ(cast<IntegerAttr>(apIntAttr).getInt(), 30);
+
+  // Shaped type override: the result is expected to have 64-bit elements.
+  ShapedType widened =
+      typeIface.overrideShapedType(RankedTensorType::get({4}, intTy));
+  auto widenedElemTy = cast<IntegerType>(widened.getElementType());
+  EXPECT_EQ(widenedElemTy.getWidth(), 64u);
+
+  // Constant op: must be an arith.constant inserted into the module body.
+  Operation *created = typeIface.createConstantOp(b, unknownLoc, scalarAttr);
+  ASSERT_TRUE(created != nullptr);
+  EXPECT_TRUE(isa<arith::ConstantOp>(created));
+  EXPECT_FALSE(body->empty());
+}
+
+// Attaches distinct models to i32 and f32 and checks that each type keeps
+// its own behavior for attribute and constant op creation.
+TEST(ConstantLikeInterfaceTest, MultipleTypesWithDifferentBehavior) {
+  MLIRContext ctx;
+  ctx.loadDialect<arith::ArithDialect>();
+  IntegerType::attachInterface<TestConstantLikeIntegerModel>(ctx);
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(ctx);
+
+  OwningOpRef<ModuleOp> moduleRef = ModuleOp::create(UnknownLoc::get(&ctx));
+  Block *body = moduleRef->getBody();
+  OpBuilder b(body, body->begin());
+  Location unknownLoc = b.getUnknownLoc();
+
+  auto intTy = IntegerType::get(&ctx, 32);
+  auto fpTy = Float32Type::get(&ctx);
+
+  auto intIface = cast<ConstantLikeInterface>(intTy);
+  auto fpIface = cast<ConstantLikeInterface>(fpTy);
+
+  // Same input, different model: the integer model is expected to give
+  // 10 * 2 = 20 and the float model 10 + 100 = 110.
+  TypedAttr intAttr = intIface.createConstantAttr(10);
+  TypedAttr fpAttr = fpIface.createConstantAttr(10);
+
+  EXPECT_EQ(cast<IntegerAttr>(intAttr).getInt(), 20);
+  EXPECT_DOUBLE_EQ(cast<FloatAttr>(fpAttr).getValueAsDouble(), 110.0);
+
+  // Materialize one constant op per attribute.
+  Operation *intOp = intIface.createConstantOp(b, unknownLoc, intAttr);
+  Operation *fpOp = fpIface.createConstantOp(b, unknownLoc, fpAttr);
+  ASSERT_TRUE(intOp != nullptr);
+  ASSERT_TRUE(fpOp != nullptr);
+
+  // Both ops should be arith.constant.
+  EXPECT_TRUE(isa<arith::ConstantOp>(intOp));
+  EXPECT_TRUE(isa<arith::ConstantOp>(fpOp));
+
+  // Only the result types differ between the two ops.
+  Type intResultTy = cast<arith::ConstantOp>(intOp).getResult().getType();
+  Type fpResultTy = cast<arith::ConstantOp>(fpOp).getResult().getType();
+  EXPECT_TRUE(isa<IntegerType>(intResultTy));
+  EXPECT_TRUE(isa<Float32Type>(fpResultTy));
+}
+
+} // namespace
diff --git a/utils/bazel/MODULE.bazel.lock b/utils/bazel/MODULE.bazel.lock
index 7070fd2fa9c61..b01e8474f70e0 100644
--- a/utils/bazel/MODULE.bazel.lock
+++ b/utils/bazel/MODULE.bazel.lock
@@ -234,7 +234,7 @@
   "moduleExtensions": {
     "//:extensions.bzl%llvm_repos_extension": {
       "general": {
-        "bzlTransitiveDigest": "9jGazpNxASw0pQCCKAMsxGYnVBJH8Mkddp3w7yRm6eU=",
+        "bzlTransitiveDigest": "DqyjXvFpPyfsmyt58EuR0v0Xy+nPN74iTGxjOW99OcQ=",
         "usagesDigest": "X0yUkkWyxQ2Y5oZVDkRSE/K4YkDWo1IjhHsL+1weKyU=",
         "recordedFileInputs": {},
         "recordedDirentsInputs": {},
@@ -300,7 +300,8 @@
             "repoRuleId": "@@bazel_tools//tools/build_defs/repo:http.bzl%http_archive",
             "attributes": {
               "urls": [
-                "https://versaweb.dl.sourceforge.net/project/perfmon2/libpfm4/libpfm-4.13.0.tar.gz"
+                "https://versaweb.dl.sourceforge.net/project/perfmon2/libpfm4/libpfm-4.13.0.tar.gz",
+                "https://sourceforge.net/projects/perfmon2/files/libpfm4/libpfm-4.13.0.tar.gz"
               ],
               "sha256": "d18b97764c755528c1051d376e33545d0eb60c6ebf85680436813fa5b04cc3d1",
               "strip_prefix": "libpfm-4.13.0",
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 73adfa40d831f..b10caecc749f6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -113,6 +113,17 @@ gentbl_cc_library(
     deps = [":OpBaseTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "ConstantLikeInterfaceIncGen",  # mlir-tblgen outputs for ConstantLikeInterface.td.
+    tbl_outs = {
+        "include/mlir/IR/ConstantLikeInterface.h.inc": ["-gen-type-interface-decls"],  # Type interface declarations.
+        "include/mlir/IR/ConstantLikeInterface.cpp.inc": ["-gen-type-interface-defs"],  # Type interface definitions.
+    },
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/IR/ConstantLikeInterface.td",
+    deps = [":OpBaseTdFiles"],
+)
+
 gentbl_cc_library(
     name = "TensorEncodingIncGen",
     tbl_outs = {
@@ -250,6 +261,13 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "ConstantLikeInterfaceTdFiles",  # Packages ConstantLikeInterface.td as a td_library.
+    srcs = ["include/mlir/IR/ConstantLikeInterface.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],  # NOTE(review): ConstantLikeInterfaceIncGen lists :OpBaseTdFiles directly, not this target; confirm intended.
+)
+
 td_library(
     name = "FunctionInterfacesTdFiles",
     srcs = ["include/mlir/Interfaces/FunctionInterfaces.td"],
@@ -421,6 +439,7 @@ cc_library(
         ":BuiltinTypesIncGen",
         ":BytecodeOpInterfaceIncGen",
         ":CallOpInterfacesIncGen",
+        ":ConstantLikeInterfaceIncGen",
         ":DataLayoutInterfacesIncGen",
         ":InferIntRangeInterfaceIncGen",
         ":OpAsmInterfaceIncGen",
diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
index dfb5be5f85027..37277fdcb0699 100644
--- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
@@ -40,6 +40,7 @@ cc_test(
         "//llvm:Core",
         "//llvm:Remarks",
         "//llvm:Support",
+        "//mlir:ArithDialect",
         "//mlir:BytecodeReader",
         "//mlir:CallOpInterfaces",
         "//mlir:FunctionInterfaces",



More information about the Mlir-commits mailing list