[Mlir-commits] [llvm] [mlir] [MLIR][IR] Add ConstantLikeInterface for extensible constant creation (PR #177740)
Ryan Kim
llvmlistbot at llvm.org
Sat Jan 24 01:03:19 PST 2026
https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/177740
>From 16e367ed96648450d69cc94b5632d4f15ef7773d Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Fri, 23 Jan 2026 08:53:45 +0900
Subject: [PATCH] [MLIR][IR] Add ConstantLikeInterface for extensible constant
creation
Introduces a new type interface `ConstantLikeInterface` that allows custom
types to define their own constant creation logic. This is particularly
useful for domain-specific types (e.g., finite field elements, custom
numeric types) that need special handling for constant attributes and ops.
Interface methods:
- `createConstantAttr(int64_t)`: Create a TypedAttr from a scalar value
- `createConstantAttrFromValues(ArrayRef<APInt>)`: Create a TypedAttr from
multiple APInt values (useful for composite types like extension fields)
- `createConstantOp(OpBuilder&, Location, TypedAttr)`: Create a dialect-
specific constant operation instead of relying on arith.constant
- `overrideShapedType(ShapedType)`: Override the shaped type used for
tensor/vector constants (e.g., to use a storage type for the elements); the
overridden element type must match the type of the splat element attribute
Integration points:
- `Builder::getZeroAttr` and `Builder::getOneAttr` now check for this
interface first, enabling custom types to provide their own constant
creation logic before falling back to built-in type handling.
- ElementwiseOpFusion uses this interface to create scalar constants when
fusing splat operands into linalg.generic operations.
---
mlir/include/mlir/IR/CMakeLists.txt | 5 +
mlir/include/mlir/IR/ConstantLikeInterface.h | 34 ++
mlir/include/mlir/IR/ConstantLikeInterface.td | 82 +++
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 29 +-
mlir/lib/IR/Builders.cpp | 35 +-
mlir/lib/IR/CMakeLists.txt | 2 +
mlir/lib/IR/ConstantLikeInterface.cpp | 15 +
mlir/unittests/IR/CMakeLists.txt | 5 +-
.../IR/ConstantLikeInterfaceTest.cpp | 516 ++++++++++++++++++
utils/bazel/MODULE.bazel.lock | 5 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 19 +
.../mlir/unittests/BUILD.bazel | 1 +
12 files changed, 737 insertions(+), 11 deletions(-)
create mode 100644 mlir/include/mlir/IR/ConstantLikeInterface.h
create mode 100644 mlir/include/mlir/IR/ConstantLikeInterface.td
create mode 100644 mlir/lib/IR/ConstantLikeInterface.cpp
create mode 100644 mlir/unittests/IR/ConstantLikeInterfaceTest.cpp
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 0b3079cde568d..d7d1f116987a4 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -55,6 +55,11 @@ mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_mlir_generic_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)

+set(LLVM_TARGET_DEFINITIONS ConstantLikeInterface.td)
+mlir_tablegen(ConstantLikeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(ConstantLikeInterface.cpp.inc -gen-type-interface-defs)
+add_mlir_generic_tablegen_target(MLIRConstantLikeInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
 mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/ConstantLikeInterface.h b/mlir/include/mlir/IR/ConstantLikeInterface.h
new file mode 100644
index 0000000000000..e8d4a2794d2c1
--- /dev/null
+++ b/mlir/include/mlir/IR/ConstantLikeInterface.h
@@ -0,0 +1,26 @@
+//===- ConstantLikeInterface.h - Constant Creation Interface ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for type interfaces that allow types to
+// define how to create constant attributes and operations for themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_CONSTANTLIKEINTERFACE_H
+#define MLIR_IR_CONSTANTLIKEINTERFACE_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/APInt.h"
+
+#include "mlir/IR/ConstantLikeInterface.h.inc"
+
+#endif // MLIR_IR_CONSTANTLIKEINTERFACE_H
diff --git a/mlir/include/mlir/IR/ConstantLikeInterface.td b/mlir/include/mlir/IR/ConstantLikeInterface.td
new file mode 100644
index 0000000000000..af5d948bc7260
--- /dev/null
+++ b/mlir/include/mlir/IR/ConstantLikeInterface.td
@@ -0,0 +1,100 @@
+//===- ConstantLikeInterface.td - Constant creation --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for type interfaces that allow types to
+// define how to create constant attributes and operations for themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_CONSTANTLIKEINTERFACE_TD_
+#define MLIR_IR_CONSTANTLIKEINTERFACE_TD_
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ConstantLikeInterface
+//===----------------------------------------------------------------------===//
+
+def ConstantLikeInterface : TypeInterface<"ConstantLikeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows types to define how to create constant attributes
+    and operations for their values. This decouples generic MLIR code from
+    specific constant operation types, enabling better layering and
+    extensibility.
+
+    Types implementing this interface can provide custom constant creation
+    logic, which is particularly useful for domain-specific types (e.g., field
+    elements, custom numeric types) that need special constant handling.
+
+    Every method has a default implementation that declines (returns null, or
+    returns the given shaped type unchanged), so implementations only need to
+    override the hooks they support.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant attribute for this type from the given int64 value.
+        Returns null if the type does not support this operation.
+      }],
+      /*retTy=*/"::mlir::TypedAttr",
+      /*methodName=*/"createConstantAttr",
+      /*args=*/(ins "int64_t":$value),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return {}; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant attribute for this type from the given APInt
+        values (more than one value is useful for composite types such as
+        extension fields). Each APInt should be compatible with this type's
+        bit width and semantics.
+        Returns null if the type does not support this operation.
+      }],
+      /*retTy=*/"::mlir::TypedAttr",
+      /*methodName=*/"createConstantAttrFromValues",
+      /*args=*/(ins "::llvm::ArrayRef<::llvm::APInt>":$values),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return {}; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Creates a constant operation for this type with the given attribute.
+        The builder's insertion point should be set before calling this method.
+        The created operation is expected to have exactly one result.
+        Returns null if the type does not support constant operation creation.
+      }],
+      /*retTy=*/"::mlir::Operation *",
+      /*methodName=*/"createConstantOp",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                   "::mlir::Location":$loc,
+                   "::mlir::TypedAttr":$attr),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return nullptr; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Allows custom types to override the shaped type (tensor/vector) used
+        for constant creation. This is useful when a custom element type needs
+        to substitute the shaped type with a different representation (e.g.,
+        using a storage type instead of the semantic type).
+        Splat attributes such as DenseElementsAttr require the element
+        attribute's type to equal the returned shaped type's element type.
+        Returns the potentially modified shaped type.
+      }],
+      /*retTy=*/"::mlir::ShapedType",
+      /*methodName=*/"overrideShapedType",
+      /*args=*/(ins "::mlir::ShapedType":$shapedType),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return shapedType; }]
+    >
+  ];
+}
+
+#endif // MLIR_IR_CONSTANTLIKEINTERFACE_TD_
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 72acd02d0d13d..cff08755052ca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/ConstantLikeInterface.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
@@ -2307,8 +2308,26 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       }

       // Create a constant scalar value from the splat constant.
-      Value scalarConstant =
-          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+      Region &region = genericOp->getRegion(0);
+      Block &entryBlock = *region.begin();
+      Value argument = entryBlock.getArgument(opOperand->getOperandNumber());
+
+      // Prefer a dialect-specific constant op when the block argument's type
+      // implements ConstantLikeInterface; fall back to arith.constant when the
+      // type lacks the interface or its hook declines by returning null.
+      Value scalarConstant;
+      if (auto constType =
+              dyn_cast<ConstantLikeInterface>(argument.getType())) {
+        if (Operation *constOp = constType.createConstantOp(
+                rewriter, def->getLoc(), constantAttr)) {
+          assert(constOp->getNumResults() == 1 &&
+                 "ConstantLikeInterface must create a single-result op");
+          scalarConstant = constOp->getResult(0);
+        }
+      }
+      if (!scalarConstant)
+        scalarConstant =
+            arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);

       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp =
@@ -2323,11 +2342,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {

       // Map the block argument corresponding to the replaced argument with the
       // scalar constant.
-      Region &region = genericOp->getRegion(0);
-      Block &entryBlock = *region.begin();
       IRMapping mapping;
-      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
-                  scalarConstant);
+      mapping.map(argument, scalarConstant);
       Region &fusedRegion = fusedOp->getRegion(0);
       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                  mapping);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8f199b60fccdc..d7b2d96c71a3c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ConstantLikeInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -322,6 +323,14 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }

 TypedAttr Builder::getZeroAttr(Type type) {
+  // Check if the type implements ConstantLikeInterface for custom constant
+  // creation.
+  if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(type)) {
+    if (auto attr = constType.createConstantAttr(0))
+      return attr;
+  }
+
+  // Fallback to built-in type handling.
   if (llvm::isa<FloatType>(type))
     return getFloatAttr(type, 0.0);
   if (llvm::isa<IndexType>(type))
@@ -331,15 +340,33 @@ TypedAttr Builder::getZeroAttr(Type type) {
                           APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
+    auto elementType = vtType.getElementType();
+
+    auto element = getZeroAttr(elementType);
     if (!element)
       return {};
+    // Custom element types may substitute the shaped type (e.g. to use a
+    // storage element type). DenseElementsAttr requires the splat element's
+    // type to match the shaped element type, so return null rather than
+    // asserting when the override and the element attribute disagree.
+    if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(elementType))
+      vtType = constType.overrideShapedType(vtType);
+    if (!vtType || vtType.getElementType() != element.getType())
+      return {};
     return DenseElementsAttr::get(vtType, element);
   }
   return {};
 }

 TypedAttr Builder::getOneAttr(Type type) {
+  // Check if the type implements ConstantLikeInterface for custom constant
+  // creation.
+  if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(type)) {
+    if (auto attr = constType.createConstantAttr(1))
+      return attr;
+  }
+
+  // Fallback to built-in type handling.
   if (llvm::isa<FloatType>(type))
     return getFloatAttr(type, 1.0);
   if (llvm::isa<IndexType>(type))
@@ -349,9 +376,19 @@ TypedAttr Builder::getOneAttr(Type type) {
                           APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto elementType = vtType.getElementType();
+
+    auto element = getOneAttr(elementType);
     if (!element)
       return {};
+    // Custom element types may substitute the shaped type (e.g. to use a
+    // storage element type). DenseElementsAttr requires the splat element's
+    // type to match the shaped element type, so return null rather than
+    // asserting when the override and the element attribute disagree.
+    if (auto constType = llvm::dyn_cast<ConstantLikeInterface>(elementType))
+      vtType = constType.overrideShapedType(vtType);
+    if (!vtType || vtType.getElementType() != element.getType())
+      return {};
     return DenseElementsAttr::get(vtType, element);
   }
   return {};
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index d95bdc957e3c2..cf12cb5e53299 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -18,6 +18,7 @@ add_mlir_library(MLIRIR
   BuiltinDialectBytecode.cpp
   BuiltinTypes.cpp
   BuiltinTypeInterfaces.cpp
+  ConstantLikeInterface.cpp
   Diagnostics.cpp
   Dialect.cpp
   DialectResourceBlobManager.cpp
@@ -61,6 +62,7 @@ add_mlir_library(MLIRIR
   MLIRBuiltinTypeInterfacesIncGen
   MLIRCallInterfacesIncGen
   MLIRCastInterfacesIncGen
+  MLIRConstantLikeInterfaceIncGen
   MLIRDataLayoutInterfacesIncGen
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
diff --git a/mlir/lib/IR/ConstantLikeInterface.cpp b/mlir/lib/IR/ConstantLikeInterface.cpp
new file mode 100644
index 0000000000000..028434eea5771
--- /dev/null
+++ b/mlir/lib/IR/ConstantLikeInterface.cpp
@@ -0,0 +1,15 @@
+//===- ConstantLikeInterface.cpp - Constant Creation Interface ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ConstantLikeInterface.h"
+
+namespace mlir {
+
+#include "mlir/IR/ConstantLikeInterface.cpp.inc"
+
+} // namespace mlir
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index dd3b110dcd295..af04c8b5fd9ed 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_unittest(MLIRIRTests
   AffineMapTest.cpp
   AttributeTest.cpp
   AttrTypeReplacerTest.cpp
+  ConstantLikeInterfaceTest.cpp
   Diagnostic.cpp
   DialectTest.cpp
   DistinctAttributeAllocatorTest.cpp
@@ -24,9 +25,11 @@ add_mlir_unittest(MLIRIRTests
   BlobManagerTest.cpp
   DEPENDS
+  MLIRConstantLikeInterfaceIncGen
   MLIRTestInterfaceIncGen
 )

 target_include_directories(MLIRIRTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
 mlir_target_link_libraries(MLIRIRTests PRIVATE MLIRIR)
+mlir_target_link_libraries(MLIRIRTests PRIVATE MLIRArithDialect)
 target_link_libraries(MLIRIRTests PRIVATE MLIRTestDialect)
 target_link_libraries(MLIRIRTests PRIVATE MLIRRemarkStreamer)
diff --git a/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp b/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp
new file mode 100644
index 0000000000000..d956e1fdb5e2d
--- /dev/null
+++ b/mlir/unittests/IR/ConstantLikeInterfaceTest.cpp
@@ -0,0 +1,516 @@
+//===- ConstantLikeInterfaceTest.cpp - Test ConstantLikeInterface ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for ConstantLikeInterface and its Builder integration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ConstantLikeInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test External Models
+//===----------------------------------------------------------------------===//
+
+/// External interface model for IntegerType that provides custom constant
+/// creation. This model multiplies values by 2 to demonstrate the interface
+/// is being used.
+struct TestConstantLikeIntegerModel
+ : public ConstantLikeInterface::ExternalModel<TestConstantLikeIntegerModel,
+ IntegerType> {
+ /// Creates an IntegerAttr with the value multiplied by 2 (for testing).
+ TypedAttr createConstantAttr(Type type, int64_t value) const {
+ auto intType = llvm::cast<IntegerType>(type);
+ // Multiply by 2 to verify this method is called instead of the default.
+ return IntegerAttr::get(intType, value * 2);
+ }
+
+ /// Creates an IntegerAttr from exactly one APInt value plus 10 (for
+ /// testing); returns null for any other number of values.
+ TypedAttr createConstantAttrFromValues(Type type,
+ ArrayRef<APInt> values) const {
+ if (values.size() != 1)
+ return {};
+ auto intType = llvm::cast<IntegerType>(type);
+ // Add 10 to verify this method is called.
+ APInt result = values[0] + 10;
+ return IntegerAttr::get(intType, result);
+ }
+
+ /// Creates an arith.constant operation for the integer type.
+ Operation *createConstantOp(Type type, OpBuilder &builder, Location loc,
+ TypedAttr attr) const {
+ return arith::ConstantOp::create(builder, loc, attr);
+ }
+
+ /// Overrides the shaped type by changing the element type to i64 to
+ /// simulate a storage type. createConstantAttr keeps the original width.
+ ShapedType overrideShapedType(Type type, ShapedType shapedType) const {
+ // Override to use i64 as storage type for demonstration.
+ IntegerType storageType = IntegerType::get(type.getContext(), 64);
+ return shapedType.clone(storageType);
+ }
+};
+
+/// External model for Float32Type: offsets values and keeps shaped types.
+struct TestConstantLikeFloatModel
+ : public ConstantLikeInterface::ExternalModel<TestConstantLikeFloatModel,
+ Float32Type> {
+ /// Creates a FloatAttr with the value plus 100 (for testing).
+ TypedAttr createConstantAttr(Type type, int64_t value) const {
+ // Add 100 to verify this method is called.
+ return FloatAttr::get(type, static_cast<double>(value) + 100.0);
+ }
+
+ TypedAttr createConstantAttrFromValues(Type type,
+ ArrayRef<APInt> values) const {
+ if (values.size() != 1)
+ return {};
+ // Convert APInt to double and add 50.
+ double d = static_cast<double>(values[0].getSExtValue()) + 50.0;
+ return FloatAttr::get(type, d);
+ }
+
+ Operation *createConstantOp(Type type, OpBuilder &builder, Location loc,
+ TypedAttr attr) const {
+ return arith::ConstantOp::create(builder, loc, attr);
+ }
+
+ ShapedType overrideShapedType(Type type, ShapedType shapedType) const {
+ // Keep unchanged for float.
+ return shapedType;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Basic Interface Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, InterfaceNotAttached) {
+ // Without attaching the interface, types should not implement it.
+ MLIRContext context;
+ IntegerType i32 = IntegerType::get(&context, 32);
+ EXPECT_FALSE(isa<ConstantLikeInterface>(i32));
+
+ Float32Type f32 = Float32Type::get(&context);
+ EXPECT_FALSE(isa<ConstantLikeInterface>(f32));
+}
+
+TEST(ConstantLikeInterfaceTest, AttachInterfaceToIntegerType) {
+ MLIRContext context;
+
+ // Attach the interface to IntegerType.
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = dyn_cast<ConstantLikeInterface>(i32);
+ ASSERT_TRUE(iface != nullptr);
+
+ // Test createConstantAttr - our model multiplies by 2.
+ TypedAttr attr = iface.createConstantAttr(5);
+ ASSERT_TRUE(attr != nullptr);
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ ASSERT_TRUE(intAttr != nullptr);
+ EXPECT_EQ(intAttr.getInt(), 10); // 5 * 2 = 10
+}
+
+TEST(ConstantLikeInterfaceTest, AttachInterfaceToFloatType) {
+ MLIRContext context;
+
+ // Attach the interface to Float32Type.
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ Float32Type f32 = Float32Type::get(&context);
+ auto iface = dyn_cast<ConstantLikeInterface>(f32);
+ ASSERT_TRUE(iface != nullptr);
+
+ // Test createConstantAttr - our model adds 100.
+ TypedAttr attr = iface.createConstantAttr(1);
+ ASSERT_TRUE(attr != nullptr);
+ auto floatAttr = dyn_cast<FloatAttr>(attr);
+ ASSERT_TRUE(floatAttr != nullptr);
+ EXPECT_DOUBLE_EQ(floatAttr.getValueAsDouble(), 101.0); // 1 + 100 = 101
+}
+
+TEST(ConstantLikeInterfaceTest, InterfaceNotSharedAcrossContexts) {
+ MLIRContext context1;
+ MLIRContext context2;
+
+ // Attach interface only to context1.
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context1);
+
+ IntegerType i32Ctx1 = IntegerType::get(&context1, 32);
+ IntegerType i32Ctx2 = IntegerType::get(&context2, 32);
+
+ // context1 should have the interface.
+ EXPECT_TRUE(isa<ConstantLikeInterface>(i32Ctx1));
+
+ // context2 should NOT have the interface.
+ EXPECT_FALSE(isa<ConstantLikeInterface>(i32Ctx2));
+}
+
+//===----------------------------------------------------------------------===//
+// createConstantAttrFromValues Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValues) {
+ MLIRContext context;
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = cast<ConstantLikeInterface>(i32);
+
+ // Test createConstantAttrFromValues with a single APInt.
+ APInt value(32, 7); // 7
+ SmallVector<APInt> values = {value};
+ TypedAttr attr = iface.createConstantAttrFromValues(values);
+ ASSERT_TRUE(attr != nullptr);
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ ASSERT_TRUE(intAttr != nullptr);
+ EXPECT_EQ(intAttr.getInt(), 17); // 7 + 10 = 17
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValuesFloat) {
+ MLIRContext context;
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ Float32Type f32 = Float32Type::get(&context);
+ auto iface = cast<ConstantLikeInterface>(f32);
+
+ // Test createConstantAttrFromValues for float.
+ APInt value(32, 25);
+ SmallVector<APInt> values = {value};
+ TypedAttr attr = iface.createConstantAttrFromValues(values);
+ ASSERT_TRUE(attr != nullptr);
+ auto floatAttr = dyn_cast<FloatAttr>(attr);
+ ASSERT_TRUE(floatAttr != nullptr);
+ EXPECT_DOUBLE_EQ(floatAttr.getValueAsDouble(), 75.0); // 25 + 50 = 75
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantAttrFromValuesInvalidSize) {
+ MLIRContext context;
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = cast<ConstantLikeInterface>(i32);
+
+ // Test with wrong number of values (our model expects exactly 1).
+ APInt v1(32, 1), v2(32, 2);
+ SmallVector<APInt> values = {v1, v2};
+ TypedAttr attr = iface.createConstantAttrFromValues(values);
+ EXPECT_TRUE(attr == nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// createConstantOp Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, CreateConstantOp) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ // Create a module to hold operations.
+ OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module->getBody(), module->getBody()->begin());
+ Location loc = builder.getUnknownLoc();
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = cast<ConstantLikeInterface>(i32);
+
+ // Create an attribute and then a constant operation.
+ TypedAttr attr = iface.createConstantAttr(21);
+ ASSERT_TRUE(attr != nullptr);
+
+ Operation *constOp = iface.createConstantOp(builder, loc, attr);
+ ASSERT_TRUE(constOp != nullptr);
+
+ // Verify it's an arith.constant.
+ EXPECT_TRUE(isa<arith::ConstantOp>(constOp));
+ auto arithConst = cast<arith::ConstantOp>(constOp);
+ auto resultAttr = dyn_cast<IntegerAttr>(arithConst.getValue());
+ ASSERT_TRUE(resultAttr != nullptr);
+ EXPECT_EQ(resultAttr.getInt(), 42); // 21 * 2 = 42
+}
+
+TEST(ConstantLikeInterfaceTest, CreateConstantOpFloat) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module->getBody(), module->getBody()->begin());
+ Location loc = builder.getUnknownLoc();
+
+ Float32Type f32 = Float32Type::get(&context);
+ auto iface = cast<ConstantLikeInterface>(f32);
+
+ TypedAttr attr = iface.createConstantAttr(50);
+ Operation *constOp = iface.createConstantOp(builder, loc, attr);
+ ASSERT_TRUE(constOp != nullptr);
+
+ EXPECT_TRUE(isa<arith::ConstantOp>(constOp));
+ auto arithConst = cast<arith::ConstantOp>(constOp);
+ auto resultAttr = dyn_cast<FloatAttr>(arithConst.getValue());
+ ASSERT_TRUE(resultAttr != nullptr);
+ EXPECT_DOUBLE_EQ(resultAttr.getValueAsDouble(), 150.0); // 50 + 100 = 150
+}
+
+//===----------------------------------------------------------------------===//
+// overrideShapedType Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedType) {
+ MLIRContext context;
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = cast<ConstantLikeInterface>(i32);
+
+ // Create a tensor type with i32 elements.
+ RankedTensorType tensorType = RankedTensorType::get({2, 3}, i32);
+
+ // Override should change element type to i64.
+ ShapedType overridden = iface.overrideShapedType(tensorType);
+ ASSERT_TRUE(overridden != nullptr);
+
+ // Shape should be preserved.
+ EXPECT_EQ(overridden.getShape(), tensorType.getShape());
+
+ // Element type should now be i64.
+ auto elementType = dyn_cast<IntegerType>(overridden.getElementType());
+ ASSERT_TRUE(elementType != nullptr);
+ EXPECT_EQ(elementType.getWidth(), 64u);
+}
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedTypePreservesShape) {
+ MLIRContext context;
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ IntegerType i16 = IntegerType::get(&context, 16);
+ auto iface = cast<ConstantLikeInterface>(i16);
+
+ // Test with various shapes.
+ RankedTensorType tensor1D = RankedTensorType::get({10}, i16);
+ RankedTensorType tensor3D = RankedTensorType::get({2, 3, 4}, i16);
+
+ ShapedType overridden1D = iface.overrideShapedType(tensor1D);
+ ShapedType overridden3D = iface.overrideShapedType(tensor3D);
+
+ EXPECT_EQ(overridden1D.getShape(), tensor1D.getShape());
+ EXPECT_EQ(overridden3D.getShape(), tensor3D.getShape());
+
+ // Both should have i64 element type.
+ EXPECT_EQ(cast<IntegerType>(overridden1D.getElementType()).getWidth(), 64u);
+ EXPECT_EQ(cast<IntegerType>(overridden3D.getElementType()).getWidth(), 64u);
+}
+
+TEST(ConstantLikeInterfaceTest, OverrideShapedTypeFloat) {
+ MLIRContext context;
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ Float32Type f32 = Float32Type::get(&context);
+ auto iface = cast<ConstantLikeInterface>(f32);
+
+ // Create a tensor type with f32 elements.
+ RankedTensorType tensorType = RankedTensorType::get({4, 5}, f32);
+
+ // Float model keeps type unchanged.
+ ShapedType overridden = iface.overrideShapedType(tensorType);
+ EXPECT_EQ(overridden, tensorType);
+}
+
+//===----------------------------------------------------------------------===//
+// Builder Integration Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, BuilderGetZeroAttrUsesInterface) {
+  MLIRContext context;
+
+  // First, test without interface - should use default behavior.
+  Builder builderWithoutInterface(&context);
+  Float32Type f32 = Float32Type::get(&context);
+
+  TypedAttr defaultZero = builderWithoutInterface.getZeroAttr(f32);
+  ASSERT_TRUE(defaultZero != nullptr);
+  auto defaultFloatAttr = dyn_cast<FloatAttr>(defaultZero);
+  ASSERT_TRUE(defaultFloatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(defaultFloatAttr.getValueAsDouble(), 0.0); // Default: 0.0
+
+  // Now attach the interface. The integer model maps 0 to 0 * 2 = 0, which
+  // cannot prove the hook ran, so use the float model (adds 100) instead.
+  Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+  // After attaching, getZeroAttr should use our interface.
+  Builder builderWithInterface(&context);
+  TypedAttr customZero = builderWithInterface.getZeroAttr(f32);
+  ASSERT_TRUE(customZero != nullptr);
+  auto customFloatAttr = dyn_cast<FloatAttr>(customZero);
+  ASSERT_TRUE(customFloatAttr != nullptr);
+  EXPECT_DOUBLE_EQ(customFloatAttr.getValueAsDouble(), 100.0); // 0 + 100
+}
+
+TEST(ConstantLikeInterfaceTest, BuilderGetOneAttrUsesInterface) {
+ MLIRContext context;
+
+ // Attach the interface to IntegerType.
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ Builder builder(&context);
+ IntegerType i32 = IntegerType::get(&context, 32);
+
+ // getOneAttr should use our interface which multiplies by 2.
+ TypedAttr customOne = builder.getOneAttr(i32);
+ ASSERT_TRUE(customOne != nullptr);
+ auto customIntAttr = dyn_cast<IntegerAttr>(customOne);
+ ASSERT_TRUE(customIntAttr != nullptr);
+ EXPECT_EQ(customIntAttr.getInt(), 2); // 1 * 2 = 2
+}
+
+TEST(ConstantLikeInterfaceTest, ShapedTypeWithInterface) {
+ MLIRContext context;
+
+ // Use Float32Type: its model returns the shaped type unchanged, so the f32
+ // element attribute matches the tensor element type DenseElementsAttr needs.
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ Builder builder(&context);
+ Float32Type f32 = Float32Type::get(&context);
+ RankedTensorType tensorType = RankedTensorType::get({2, 3}, f32);
+
+ // getZeroAttr for tensor should create DenseElementsAttr.
+ TypedAttr zeroTensor = builder.getZeroAttr(tensorType);
+ ASSERT_TRUE(zeroTensor != nullptr);
+ auto denseAttr = dyn_cast<DenseElementsAttr>(zeroTensor);
+ ASSERT_TRUE(denseAttr != nullptr);
+
+ // Check that all elements are 100.0 (0 + 100 = 100 from our float model).
+ for (auto val : denseAttr.getValues<APFloat>()) {
+ EXPECT_DOUBLE_EQ(val.convertToDouble(), 100.0);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Delayed Registration Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, DelayedRegistration) {
+ DialectRegistry registry;
+ registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(*ctx);
+ });
+
+ MLIRContext context(registry);
+ IntegerType i32 = IntegerType::get(&context, 32);
+
+ // Interface should be available.
+ EXPECT_TRUE(isa<ConstantLikeInterface>(i32));
+
+ // Test that the interface works correctly.
+ auto iface = cast<ConstantLikeInterface>(i32);
+ TypedAttr attr = iface.createConstantAttr(7);
+ auto intAttr = cast<IntegerAttr>(attr);
+ EXPECT_EQ(intAttr.getInt(), 14); // 7 * 2 = 14
+}
+
+//===----------------------------------------------------------------------===//
+// Combined Workflow Tests
+//===----------------------------------------------------------------------===//
+
+TEST(ConstantLikeInterfaceTest, FullWorkflowWithAllMethods) {
+ // Test a complete workflow using all interface methods.
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+
+ OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module->getBody(), module->getBody()->begin());
+ Location loc = builder.getUnknownLoc();
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ auto iface = cast<ConstantLikeInterface>(i32);
+
+ // Step 1: Create attr from int64_t.
+ TypedAttr attrFromInt = iface.createConstantAttr(5);
+ EXPECT_EQ(cast<IntegerAttr>(attrFromInt).getInt(), 10); // 5 * 2
+
+ // Step 2: Create attr from APInt values.
+ APInt apVal(32, 20);
+ TypedAttr attrFromAPInt = iface.createConstantAttrFromValues({apVal});
+ EXPECT_EQ(cast<IntegerAttr>(attrFromAPInt).getInt(), 30); // 20 + 10
+
+ // Step 3: Override shaped type.
+ RankedTensorType tensorType = RankedTensorType::get({4}, i32);
+ ShapedType overridden = iface.overrideShapedType(tensorType);
+ EXPECT_EQ(cast<IntegerType>(overridden.getElementType()).getWidth(), 64u);
+
+ // Step 4: Create constant op.
+ Operation *constOp = iface.createConstantOp(builder, loc, attrFromInt);
+ ASSERT_TRUE(constOp != nullptr);
+ EXPECT_TRUE(isa<arith::ConstantOp>(constOp));
+
+ // The builder inserted the op at the start of the module body.
+ EXPECT_FALSE(module->getBody()->empty());
+}
+
+TEST(ConstantLikeInterfaceTest, MultipleTypesWithDifferentBehavior) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ IntegerType::attachInterface<TestConstantLikeIntegerModel>(context);
+ Float32Type::attachInterface<TestConstantLikeFloatModel>(context);
+
+ OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module->getBody(), module->getBody()->begin());
+ Location loc = builder.getUnknownLoc();
+
+ IntegerType i32 = IntegerType::get(&context, 32);
+ Float32Type f32 = Float32Type::get(&context);
+
+ auto intIface = cast<ConstantLikeInterface>(i32);
+ auto floatIface = cast<ConstantLikeInterface>(f32);
+
+ // Both types should produce different results for the same input.
+ TypedAttr intAttr = intIface.createConstantAttr(10);
+ TypedAttr floatAttr = floatIface.createConstantAttr(10);
+
+ EXPECT_EQ(cast<IntegerAttr>(intAttr).getInt(), 20); // 10 * 2
+ EXPECT_DOUBLE_EQ(cast<FloatAttr>(floatAttr).getValueAsDouble(),
+ 110.0); // 10 + 100
+
+ // Create constant ops for both.
+ Operation *intConstOp = intIface.createConstantOp(builder, loc, intAttr);
+ Operation *floatConstOp =
+ floatIface.createConstantOp(builder, loc, floatAttr);
+
+ ASSERT_TRUE(intConstOp != nullptr);
+ ASSERT_TRUE(floatConstOp != nullptr);
+
+ // Both should be arith.constant but with different types.
+ EXPECT_TRUE(isa<arith::ConstantOp>(intConstOp));
+ EXPECT_TRUE(isa<arith::ConstantOp>(floatConstOp));
+
+ auto intResult = cast<arith::ConstantOp>(intConstOp).getResult().getType();
+ auto floatResult =
+ cast<arith::ConstantOp>(floatConstOp).getResult().getType();
+
+ EXPECT_TRUE(isa<IntegerType>(intResult));
+ EXPECT_TRUE(isa<Float32Type>(floatResult));
+}
+
+} // namespace
diff --git a/utils/bazel/MODULE.bazel.lock b/utils/bazel/MODULE.bazel.lock
index 7070fd2fa9c61..b01e8474f70e0 100644
--- a/utils/bazel/MODULE.bazel.lock
+++ b/utils/bazel/MODULE.bazel.lock
@@ -234,7 +234,7 @@
"moduleExtensions": {
"//:extensions.bzl%llvm_repos_extension": {
"general": {
- "bzlTransitiveDigest": "9jGazpNxASw0pQCCKAMsxGYnVBJH8Mkddp3w7yRm6eU=",
+ "bzlTransitiveDigest": "DqyjXvFpPyfsmyt58EuR0v0Xy+nPN74iTGxjOW99OcQ=",
"usagesDigest": "X0yUkkWyxQ2Y5oZVDkRSE/K4YkDWo1IjhHsL+1weKyU=",
"recordedFileInputs": {},
"recordedDirentsInputs": {},
@@ -300,7 +300,8 @@
"repoRuleId": "@@bazel_tools//tools/build_defs/repo:http.bzl%http_archive",
"attributes": {
"urls": [
- "https://versaweb.dl.sourceforge.net/project/perfmon2/libpfm4/libpfm-4.13.0.tar.gz"
+ "https://versaweb.dl.sourceforge.net/project/perfmon2/libpfm4/libpfm-4.13.0.tar.gz",
+ "https://sourceforge.net/projects/perfmon2/files/libpfm4/libpfm-4.13.0.tar.gz"
],
"sha256": "d18b97764c755528c1051d376e33545d0eb60c6ebf85680436813fa5b04cc3d1",
"strip_prefix": "libpfm-4.13.0",
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 73adfa40d831f..b10caecc749f6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -113,6 +113,17 @@ gentbl_cc_library(
deps = [":OpBaseTdFiles"],
)
+# Runs mlir-tblgen on ConstantLikeInterface.td to generate the type-interface
+# declarations (ConstantLikeInterface.h.inc) and definitions
+# (ConstantLikeInterface.cpp.inc).
+gentbl_cc_library(
+ name = "ConstantLikeInterfaceIncGen",
+ tbl_outs = {
+ "include/mlir/IR/ConstantLikeInterface.h.inc": ["-gen-type-interface-decls"],
+ "include/mlir/IR/ConstantLikeInterface.cpp.inc": ["-gen-type-interface-defs"],
+ },
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/IR/ConstantLikeInterface.td",
+ deps = [":OpBaseTdFiles"],
+)
+
gentbl_cc_library(
name = "TensorEncodingIncGen",
tbl_outs = {
@@ -250,6 +261,13 @@ gentbl_cc_library(
],
)
+# TableGen library wrapping ConstantLikeInterface.td (and its OpBase
+# dependency), presumably so other .td files can include the interface.
+# NOTE(review): no target in this patch chunk depends on it; confirm it is
+# needed or wire it into the consumers.
+td_library(
+ name = "ConstantLikeInterfaceTdFiles",
+ srcs = ["include/mlir/IR/ConstantLikeInterface.td"],
+ includes = ["include"],
+ deps = [":OpBaseTdFiles"],
+)
+
td_library(
name = "FunctionInterfacesTdFiles",
srcs = ["include/mlir/Interfaces/FunctionInterfaces.td"],
@@ -421,6 +439,7 @@ cc_library(
":BuiltinTypesIncGen",
":BytecodeOpInterfaceIncGen",
":CallOpInterfacesIncGen",
+ ":ConstantLikeInterfaceIncGen",
":DataLayoutInterfacesIncGen",
":InferIntRangeInterfaceIncGen",
":OpAsmInterfaceIncGen",
diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
index dfb5be5f85027..37277fdcb0699 100644
--- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
@@ -40,6 +40,7 @@ cc_test(
"//llvm:Core",
"//llvm:Remarks",
"//llvm:Support",
+ "//mlir:ArithDialect",
"//mlir:BytecodeReader",
"//mlir:CallOpInterfaces",
"//mlir:FunctionInterfaces",
More information about the Mlir-commits
mailing list